I have been working on an image-matching project, so I need to find correspondences between two images. To get descriptors, I need an interpolate function. However, after reading an equivalent function written in TensorFlow, I still don't understand how to implement tf.gather_nd(params, indices, batch_dims) in PyTorch, especially the batch_dims argument. I have searched Stack Overflow and found no exact equivalent yet.
The TensorFlow interpolate function I am referring to is shown below, and I have been trying to implement it in PyTorch. Its arguments are:
inputs is a dense feature map[i] taken from a for loop over the batch, so it is 3D with shape [H, W, C] (in PyTorch it is [C, H, W])
pos is a set of random point coordinates shaped like [[i, j], [i, j],...,[i, j]], so it is 2D when it enters the interpolate function (in PyTorch it is [[i,i,...,i], [j,j,...,j]])
Both are then expanded by one dimension inside the function.
I just want an exact implementation of tf.gather_nd with the batch_dims argument. Thank you!
And here's a simple example of using it:
pos = tf.ones((12, 2)) ## stands for a set of coordinates [[i, j], [i, j], ..., [i, j]]
inputs = tf.ones((4, 4, 128)) ## stands for [H, W, C] of dense feature map
outputs = interpolate(pos, inputs, nd=True)
print(outputs.get_shape()) # We get (12, 128) here
interpolate function (tf version):
def interpolate(pos, inputs, nd=True):
    pos = tf.expand_dims(pos, 0)
    inputs = tf.expand_dims(inputs, 0)
    h = tf.shape(inputs)[1]
    w = tf.shape(inputs)[2]
    i = pos[:, :, 0]
    j = pos[:, :, 1]
    i_top_left = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)
    i_top_right = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)
    i_bottom_left = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)
    i_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)
    dist_i_top_left = i - tf.cast(i_top_left, tf.float32)
    dist_j_top_left = j - tf.cast(j_top_left, tf.float32)
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left
    if nd:
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]
    interpolated_val = (
        w_top_left * tf.gather_nd(inputs, tf.stack([i_top_left, j_top_left], axis=-1), batch_dims=1) +
        w_top_right * tf.gather_nd(inputs, tf.stack([i_top_right, j_top_right], axis=-1), batch_dims=1) +
        w_bottom_left * tf.gather_nd(inputs, tf.stack([i_bottom_left, j_bottom_left], axis=-1), batch_dims=1) +
        w_bottom_right * tf.gather_nd(inputs, tf.stack([i_bottom_right, j_bottom_right], axis=-1), batch_dims=1)
    )
    interpolated_val = tf.squeeze(interpolated_val, axis=0)
    return interpolated_val
As far as I'm aware, there is no direct equivalent of tf.gather_nd in PyTorch, and implementing a generic version with batch_dims is not that simple. However, you likely don't need a generic version; given the context of your interpolate function, a version for [C, H, W] would suffice.
At the beginning of interpolate you add a singleton dimension to the front, which is the batch dimension. Setting batch_dims=1 in tf.gather_nd means there is one batch dimension at the beginning, so the gather is applied per batch element, i.e. it indexes inputs[0] with pos[0] etc. There is no benefit to adding a singleton batch dimension, because you could have just used the direct computation.
# Adding singular batch dimension
# Shape: [1, num_pos, 2]
pos = tf.expand_dims(pos, 0)
# Shape: [1, H, W, C]
inputs = tf.expand_dims(inputs, 0)
batched_result = tf.gather_nd(inputs, pos, batch_dims=1)
single_result = tf.gather_nd(inputs[0], pos[0])
# The first element in the batched result is the same as the single result
# Hence there is no benefit to adding a singular batch dimension.
tf.reduce_all(batched_result[0] == single_result) # => True
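To make the batch_dims=1 behaviour described above concrete, here is a minimal loop-based reference written in NumPy (only to avoid depending on either framework). It reflects my reading of the semantics for [B, H, W, C] inputs and [B, P, 2] positions, not TensorFlow's actual implementation; the function name gather_nd_batch1_reference is made up for this sketch.

```python
import numpy as np


def gather_nd_batch1_reference(params, indices):
    """Loop-based reference for gather_nd with one batch dimension.

    params:  [B, H, W, C]
    indices: [B, P, 2] integer (row, column) pairs
    returns: [B, P, C]
    """
    batch, num_pos = indices.shape[:2]
    out = np.empty((batch, num_pos, params.shape[-1]), dtype=params.dtype)
    for b in range(batch):
        for p in range(num_pos):
            i, j = indices[b, p]
            # Batch element b is only ever indexed with its own positions.
            out[b, p] = params[b, i, j]
    return out


rng = np.random.default_rng(0)
params = rng.standard_normal((3, 4, 4, 8))
indices = rng.integers(0, 4, size=(3, 12, 2))

looped = gather_nd_batch1_reference(params, indices)

# The same gather written with broadcasting instead of loops.
batch_idx = np.arange(3)[:, None]  # shape [3, 1], broadcasts against [3, 12]
vectorized = params[batch_idx, indices[..., 0], indices[..., 1]]

print(looped.shape)                        # (3, 12, 8)
print(np.array_equal(looped, vectorized))  # True
```

The vectorized form is the same indexing pattern that the batched PyTorch version relies on.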
Single version
In PyTorch the implementation for [H, W, C] can be done with Python's indexing. While PyTorch usually uses [C, H, W] for images, it's only a matter of what dimension to index, but let's keep them the same as in TensorFlow for the sake of comparison. If you were to index them manually, you would do it as such: inputs[pos_h[0], pos_w[0]], inputs[pos_h[1], pos_w[1]] and so on. PyTorch allows you to do that automatically by providing the indices as lists: inputs[pos_h, pos_w], where pos_h and pos_w have the same length. All you need to do is split your pos into two separate tensors, one for the indices along the height dimension and the other along the width dimension, which you also did in the TensorFlow version.
inputs = torch.randn(4, 4, 128)
# Random positions 0-3, shape: [12, 2]
pos = torch.randint(4, (12, 2))
# Positions split by dimension
pos_h = pos[:, 0]
pos_w = pos[:, 1]
# Index the inputs with the indices per dimension
gathered = inputs[pos_h, pos_w]
# Verify that it's identical to TensorFlow's output
inputs_tf = tf.convert_to_tensor(inputs.numpy())
pos_tf = tf.convert_to_tensor(pos.numpy())
gathered_tf = tf.gather_nd(inputs_tf, pos_tf)
gathered_tf = torch.from_numpy(gathered_tf.numpy())
torch.equal(gathered_tf, gathered) # => True
If you want to apply it to a tensor of size [C, H, W] instead, you only need to change the dimensions you want to index:
# For [H, W, C]
gathered = inputs[pos_h, pos_w]
# For [C, H, W]
gathered = inputs[:, pos_h, pos_w]
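With that indexing in place, the interpolate function from the question could be ported roughly as follows. This is only a sketch for a single [H, W, C] feature map and float positions of shape [P, 2]; the name interpolate_hwc is mine, and it always behaves like nd=True in the original.

```python
import torch


def interpolate_hwc(pos, inputs):
    """Bilinear sampling of an [H, W, C] tensor at float positions [P, 2]."""
    h, w = inputs.shape[0], inputs.shape[1]
    i, j = pos[:, 0], pos[:, 1]

    # Integer corner coordinates, clamped to the valid range.
    i_top = i.floor().long().clamp(0, h - 1)
    i_bottom = i.ceil().long().clamp(0, h - 1)
    j_left = j.floor().long().clamp(0, w - 1)
    j_right = j.ceil().long().clamp(0, w - 1)

    # Distances to the top-left corner, with a trailing axis to broadcast over C.
    dist_i = (i - i_top.to(i.dtype)).unsqueeze(-1)
    dist_j = (j - j_left.to(j.dtype)).unsqueeze(-1)

    return (
        (1 - dist_i) * (1 - dist_j) * inputs[i_top, j_left]
        + (1 - dist_i) * dist_j * inputs[i_top, j_right]
        + dist_i * (1 - dist_j) * inputs[i_bottom, j_left]
        + dist_i * dist_j * inputs[i_bottom, j_right]
    )


pos = torch.ones(12, 2)
inputs = torch.ones(4, 4, 128)
outputs = interpolate_hwc(pos, inputs)
print(outputs.shape)  # torch.Size([12, 128])
```

For a [C, H, W] tensor the four lookups would become inputs[:, i_top, j_left] and so on, which yields [C, P] instead of [P, C], so the weights would need to broadcast along the first axis instead.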
Batched version
Making it a batched version (for [N, H, W, C] or [N, C, H, W]) is not that difficult, and it is more appropriate, since you're dealing with batches anyway. The only tricky part is that each set of positions should only be applied to its corresponding batch element. For this, the batch dimension needs to be enumerated, which can be done with torch.arange. The batch enumeration is just the list of batch indices, which is combined with the pos_h and pos_w indices, resulting in inputs[0, pos_h[0, 0], pos_w[0, 0]], inputs[0, pos_h[0, 1], pos_w[0, 1]] ... inputs[1, pos_h[1, 0], pos_w[1, 0]] etc.
batch_size = 3
inputs = torch.randn(batch_size, 4, 4, 128)
# Random positions 0-3, different for each batch, shape: [3, 12, 2]
pos = torch.randint(4, (batch_size, 12, 2))
# Positions split by dimension
pos_h = pos[:, :, 0]
pos_w = pos[:, :, 1]
batch_enumeration = torch.arange(batch_size) # => [0, 1, 2]
# pos_h and pos_w have shape [3, 12], so the batch enumeration needs to be
# repeated 12 times per batch.
# Unsqueeze to get shape [3, 1], now the 1 could be repeated to 12, but
# broadcasting will do that automatically.
batch_enumeration = batch_enumeration.unsqueeze(1)
# Index the inputs with the indices per dimension
gathered = inputs[batch_enumeration, pos_h, pos_w]
# Again, verify that it's identical to TensorFlow's output
inputs_tf = tf.convert_to_tensor(inputs.numpy())
pos_tf = tf.convert_to_tensor(pos.numpy())
# This time with batch_dims=1
gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=1)
gathered_tf = torch.from_numpy(gathered_tf.numpy())
torch.equal(gathered_tf, gathered) # => True
Again, for [N, C, H, W], only the dimensions that are indexed need to be changed:
# For [N, H, W, C]
gathered = inputs[batch_enumeration, pos_h, pos_w]
# For [N, C, H, W]
gathered = inputs[batch_enumeration, :, pos_h, pos_w]
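One detail worth checking is where the channel dimension ends up. A small sketch, assuming the same shapes as above: because the slice sits between the index tensors, the indexed dimensions are placed first, so the [N, C, H, W] variant also returns [N, num_pos, C].

```python
import torch

batch_size, num_pos, H, W, C = 3, 12, 4, 4, 128
inputs_nhwc = torch.randn(batch_size, H, W, C)
inputs_nchw = inputs_nhwc.permute(0, 3, 1, 2)  # same data, channels first

pos = torch.randint(4, (batch_size, num_pos, 2))
pos_h, pos_w = pos[:, :, 0], pos[:, :, 1]
batch_enumeration = torch.arange(batch_size).unsqueeze(1)

gathered_nhwc = inputs_nhwc[batch_enumeration, pos_h, pos_w]
gathered_nchw = inputs_nchw[batch_enumeration, :, pos_h, pos_w]

print(gathered_nhwc.shape)  # torch.Size([3, 12, 128])
print(gathered_nchw.shape)  # torch.Size([3, 12, 128])
print(torch.equal(gathered_nhwc, gathered_nchw))  # True

# Single-image case: the index tensors are adjacent, so C stays in front.
single = inputs_nchw[0][:, pos_h[0], pos_w[0]]
print(single.shape)  # torch.Size([128, 12])
```

So the single [C, H, W] version gives [C, num_pos], and it needs a transpose if [num_pos, C] is expected downstream.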
Just a little side note on the interpolate implementation, rounding the positions (floor and ceil respectively) doesn't make sense, because indices must be integers, so it has no effect, as long as your positions are actual indices. That also results in i_top_left and i_bottom_left being the same value, but even if they are to be rounded differently, they are always 1 position apart. Furthermore, i_top_left and i_top_right are literally the same. I don't think that this function produces a meaningful output. I don't know what you're trying to achieve, but if you're looking for image interpolation you could have a look at torch.nn.functional.interpolate.
This is just an extension of Michael Jungo's batched version to the case where pos is a 2D array instead of a 1D array (excluding the batch dimension).
bs = 2
H = 4
W = 6
C = 3
inputs = torch.randn(bs, H, W, C)
pos_h = torch.randint(H, (bs, H, W))
pos_w = torch.randint(W, (bs, H, W))
batch_enumeration = torch.arange(bs)
batch_enumeration = batch_enumeration.unsqueeze(1).unsqueeze(2)
inputs.shape
Out[34]: torch.Size([2, 4, 6, 3])
pos_h.shape
Out[35]: torch.Size([2, 4, 6])
pos_w.shape
Out[36]: torch.Size([2, 4, 6])
batch_enumeration.shape
Out[37]: torch.Size([2, 1, 1])
gathered = inputs[batch_enumeration, pos_h, pos_w]
For channel first, we also need to enumerate channels
inputs = torch.randn(bs, C, H, W)
pos_h = torch.randint(H, (bs, 1, H, W))
pos_w = torch.randint(W, (bs, 1, H, W))
batch_enumeration = torch.arange(bs)
batch_enumeration = batch_enumeration.unsqueeze(1).unsqueeze(2).unsqueeze(3)
channel_enumeration = torch.arange(C)
channel_enumeration = channel_enumeration.unsqueeze(0).unsqueeze(2).unsqueeze(3)
inputs.shape
Out[49]: torch.Size([2, 3, 4, 6])
pos_h.shape
Out[50]: torch.Size([2, 1, 4, 6])
pos_w.shape
Out[51]: torch.Size([2, 1, 4, 6])
batch_enumeration.shape
Out[52]: torch.Size([2, 1, 1, 1])
channel_enumeration.shape
Out[57]: torch.Size([1, 3, 1, 1])
gathered = inputs[batch_enumeration, channel_enumeration, pos_h, pos_w]
gathered.shape
Out[59]: torch.Size([2, 3, 4, 6])
Let's verify
inputs_np = inputs.numpy()
pos_h_np = pos_h.numpy()
pos_w_np = pos_w.numpy()
gathered_np = gathered.numpy()
pos_h_np[0,0,0,0]
Out[68]: 0
pos_w_np[0,0,0,0]
Out[69]: 3
inputs_np[0,:,0,3]
Out[71]: array([ 0.79122806, -2.190181 , -0.16741803], dtype=float32)
gathered_np[0,:,0,0]
Out[72]: array([ 0.79122806, -2.190181 , -0.16741803], dtype=float32)
pos_h_np[1,0,3,4]
Out[73]: 1
pos_w_np[1,0,3,4]
Out[74]: 2
inputs_np[1,:,1,2]
Out[75]: array([ 0.9282498 , -0.34945545, 0.9136222 ], dtype=float32)
gathered_np[1,:,3,4]
Out[77]: array([ 0.9282498 , -0.34945545, 0.9136222 ], dtype=float32)
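As a possible alternative to enumerating channels, one could slice the channel dimension and permute afterwards. This is a sketch under the same shapes as above, with pos_h and pos_w of shape [bs, H, W] (no singleton channel axis); it is compared against the enumeration approach.

```python
import torch

bs, C, H, W = 2, 3, 4, 6
inputs = torch.randn(bs, C, H, W)
pos_h = torch.randint(H, (bs, H, W))
pos_w = torch.randint(W, (bs, H, W))
batch_enumeration = torch.arange(bs).view(bs, 1, 1)

# The slice sits between the index tensors, so the indexed dimensions come
# first and the channel dimension last: shape [bs, H, W, C].
gathered_channels_last = inputs[batch_enumeration, :, pos_h, pos_w]
gathered = gathered_channels_last.permute(0, 3, 1, 2)  # [bs, C, H, W]

# Reference: explicit batch and channel enumeration.
channel_enumeration = torch.arange(C).view(1, C, 1, 1)
reference = inputs[
    batch_enumeration.unsqueeze(1),
    channel_enumeration,
    pos_h.unsqueeze(1),
    pos_w.unsqueeze(1),
]

print(gathered.shape)                    # torch.Size([2, 3, 4, 6])
print(torch.equal(gathered, reference))  # True
```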
I improved on Michael Jungo's implementation. It now supports arbitrary leading batch dimensions.
import numpy as np
import torch

def gather_nd_torch(params, indices, batch_dim=1):
    """ A PyTorch port of tensorflow.gather_nd
    This implementation can handle leading batch dimensions in params, see below for detailed explanation.
    The majority of this implementation is from Michael Jungo # https://stackoverflow.com/a/61810047/6670143
    I just made it compatible with leading batch dimensions.
    Args:
      params: a tensor of dimension [b1, ..., bn, g1, ..., gm, c].
      indices: a tensor of dimension [b1, ..., bn, x, m]
      batch_dim: indicates how many batch dimensions you have; in the above example, batch_dim = n.
    Returns:
      gathered: a tensor of dimension [b1, ..., bn, x, c].
    Example:
    >>> batch_size = 5
    >>> inputs = torch.randn(batch_size, batch_size, batch_size, 4, 4, 4, 32)
    >>> pos = torch.randint(4, (batch_size, batch_size, batch_size, 12, 3))
    >>> gathered = gather_nd_torch(inputs, pos, batch_dim=3)
    >>> gathered.shape
    torch.Size([5, 5, 5, 12, 32])
    >>> inputs_tf = tf.convert_to_tensor(inputs.numpy())
    >>> pos_tf = tf.convert_to_tensor(pos.numpy())
    >>> gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=3)
    >>> gathered_tf.shape
    TensorShape([5, 5, 5, 12, 32])
    >>> gathered_tf = torch.from_numpy(gathered_tf.numpy())
    >>> torch.equal(gathered_tf, gathered)
    True
    """
    batch_dims = params.size()[:batch_dim]  # [b1, ..., bn]
    batch_size = int(np.prod(list(batch_dims)))  # b1 * ... * bn
    c_dim = params.size()[-1]  # c
    grid_dims = params.size()[batch_dim:-1]  # [g1, ..., gm]
    n_indices = indices.size(-2)  # x
    n_pos = indices.size(-1)  # m
    # reshape leading batch dims to a single batch dim
    params = params.reshape(batch_size, *grid_dims, c_dim)
    indices = indices.reshape(batch_size, n_indices, n_pos)
    # build gather indices
    # gather for each of the data points in this "batch"
    batch_enumeration = torch.arange(batch_size, device=params.device).unsqueeze(1)
    gather_dims = [indices[:, :, i] for i in range(len(grid_dims))]
    gather_dims.insert(0, batch_enumeration)
    # index with a tuple: indexing with a list of tensors is not reliably supported
    gathered = params[tuple(gather_dims)]
    # reshape back to the shape with leading batch dims
    gathered = gathered.reshape(*batch_dims, n_indices, c_dim)
    return gathered
I have also made a demo Colab notebook, you can check it here. This implementation is way faster than TF's original implementation according to my poor speed test on Colab server with a GPU instance.
I am doing semantic image segmentation with U-Net, and I am confused about the last layers for pixel classification. The U-Net code looks like this:
...
reshape = Reshape((n_classes,self.img_rows * self.img_cols))(conv9)
permute = Permute((2,1))(reshape)
activation = Activation('softmax')(permute)
model = Model(input = inputs, output = activation)
return model
...
Can I just reshape without using Permute like this?
reshape = Reshape((self.img_rows * self.img_cols, n_classes))(conv9)
Updated:
I found that the training result is not right when using the direct reshape:
reshape = Reshape((self.img_rows * self.img_cols, n_classes))(conv9) # the loss does not converge
My groundtruth is generated like this:
X = []
Y = []
im = cv2.imread(impath)
X.append(im)
seg_labels = np.zeros((height, width, n_classes))
for c, spath in enumerate(segpaths):  # c: class index, assumed to follow the order of segpaths
    mask = cv2.imread(spath, 0)
    seg_labels[:, :, c] += mask
Y.append(seg_labels.reshape(width*height, n_classes))
Why does reshaping directly not work?
You clearly misunderstand the meaning of each operation and the final goal:
final goal: classification for each pixel, i.e. softmax along the semantic class axis
how to achieve this goal in the original code? Let's see the code line by line:
reshape = Reshape((n_classes,self.img_rows * self.img_cols))(conv9) # L1
permute = Permute((2,1))(reshape) # L2
activation = Activation('softmax')(permute) # L3
L1's output dim = n_class-by-n_pixs, (n_pixs=img_rows x img_cols)
L2's output dim = n_pixs-by-n_class
L3's output dim = n_pixs-by-n_class
Note the default softmax activation is applied to the last axis, i.e. the axis that n_class stands for, which is the semantic class axis.
Therefore, this original code fulfills the final goal of semantic segmentation.
Let's revisit the code that you want to change, which is
reshape = Reshape((self.img_rows * self.img_cols, n_classes))(conv9) # L4
L4's output dim = n_pixs-by-n_class
My guess is that you think L4's output dim matches L2's, and thus L4 is a short-cut that is equivalent to executing L1 and L2.
However, matching the shape does not necessarily mean matching the physical meaning of axes. Why? A simple example will explain.
Say you have 2 semantic classes and 3 pixels. To see the difference assume all three pixels belong to the same class.
In other words, a ground truth tensor will look like this
# cls#1 cls#2
[ [0, 1], # pixel #1
[0, 1], # pixel #2
[0, 1], # pixel #3
]
Assume you have a perfect network and generate the exact response for each pixel, but your solution will create a tensor like below
# cls#1 cls#2
[ [0, 0], # pixel #1
[0, 1], # pixel #2
[1, 1], # pixel #3
]
whose shape is the same as the ground truth's, but fails to match the physical meaning of axes.
This further makes the softmax operation meaningless, because it is supposed to apply to the class dimension, but this dimension does not physically exist. As a result, it leads to the following erroneous output after applying softmax,
# cls#1 cls#2
[ [0.5, 0.5], # pixel #1
[0.27, 0.73], # pixel #2
[0.5, 0.5], # pixel #3
]
which completely messes up the training, even under this ideal assumption.
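The same effect can be reproduced in a few lines of NumPy. This is only an illustrative sketch: the score value 10.0 and the helper softmax are my own choices, standing in for a network that strongly favors class #2 at every pixel.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


n_classes, n_pixels = 2, 3
# Channels-first scores, shape (n_classes, n_pixels).
scores = np.array([[0.0, 0.0, 0.0],
                   [10.0, 10.0, 10.0]])

permuted = scores.T                             # what Reshape + Permute produces
reshaped = scores.reshape(n_pixels, n_classes)  # what the direct Reshape produces

print(softmax(permuted).argmax(axis=-1))  # [1 1 1]
print(reshaped)
print(softmax(reshaped).round(2))
```

After the direct reshape, the first and last rows mix scores from different pixels of the same class, so their softmax is an uninformative [0.5, 0.5].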
Therefore, it is a good habit to write down the physical meaning of each axis of a tensor. When you do any tensor reshape operation, ask yourself whether the physical meaning of an axis is changed in your expected way.
For example, if you have a tensor T of shape batch_dim x img_rows x img_cols x feat_dim, you can do many things and not all of them make sense (due to the problematic physical meaning of axes)
(Wrong) reshape it to whatever x feat_dim, because whatever dimension is meaningless in testing where the batch_size might be different.
(Wrong) reshape it to batch_dim x feat_dim x img_rows x img_cols, because the 2nd dimension is NOT the feature dimension and neither for the 3rd and 4th dimension.
(Correct) permute axes (3,1,2), and this will lead you the tensor of shape batch_dim x feat_dim x img_rows x img_cols, while keeping the physical meaning of each axis.
(Correct) reshape it to batch_dim x whatever x feat_dim. This is also valid, because the whatever=img_rows x img_cols is equivalent to the pixel location dimension, and both the meanings of batch_dim and feat_dim are unchanged.
Your code will still be runnable since the shape will be the same, but the result (backprops) will be different since the values of tensors will be different. For example:
arr = np.array([[[1,1,1],[1,1,1]],[[2,2,2],[2,2,2]],[[3,3,3],[3,3,3]],[[4,4,4],[4,4,4]]])
arr.shape
>>>(4, 2, 3)
# do reshape, then permute
reshape_1 = arr.reshape((4, 2*3))
np.swapaxes(reshape_1, 1, 0)
>>>array([[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]])
#do reshape directly
reshape_2 = arr.reshape(2*3, 4)
reshape_2
>>>array([[1, 1, 1, 1],
[1, 1, 2, 2],
[2, 2, 2, 2],
[3, 3, 3, 3],
[3, 3, 4, 4],
[4, 4, 4, 4]])
The Reshape and Permute are done to take the softmax at each pixel location. Adding to @meowongac's answer, Reshape preserves the order of the elements. In this case, since the channel dimension has to be moved, Reshape followed by Permute is appropriate.
Considering the case of (2,2) image with 3 values at each location,
arr = np.array([[[1,1],[1,1]],[[2,2],[2,2]],[[3,3],[3,3]]])
>>> arr.shape
(3, 2, 2)
>>> arr
array([[[1, 1],
[1, 1]],
[[2, 2],
[2, 2]],
[[3, 3],
[3, 3]]])
>>> arr[:,0,0]
array([1, 2, 3])
The channel values at each location are [1,2,3]. The goal is to swap the channel axis(length 3) to the end.
>>> arr.reshape((2,2,3))[0,0]
array([1, 1, 1]) # incorrect
>>> arr.transpose((1,2,0))[0,0] # similar to what permute does.
array([1, 2, 3]) # correct
More examples at this link: https://discuss.pytorch.org/t/how-to-change-shape-of-a-matrix-without-dispositioning-the-elements/30708
From the accepted answer in this question,
given the following
input and kernel matrices, the output of tf.nn.conv2d is
[[14 6]
[6 12]]
which makes sense. However, when I make the input and kernel matrices have 3-channels each (by repeating each original matrix), and run the same code:
# the previous input
i_grey = np.array([
[4, 3, 1, 0],
[2, 1, 0, 1],
[1, 2, 4, 1],
[3, 1, 0, 2]
])
# copy to 3-dimensions
i_rgb = np.repeat( np.expand_dims(i_grey, axis=0), 3, axis=0 )
# convert to tensor
i_rgb = tf.constant(i_rgb, dtype=tf.float32)
# make kernel depth match input; same process as input
k = np.array([
[1, 0, 1],
[2, 1, 0],
[0, 0, 1]
])
k_rgb = np.repeat( np.expand_dims(k, axis=0), 3, axis=0 )
# convert to tensor
k_rgb = tf.constant(k_rgb, dtype=tf.float32)
here's what my input and kernel matrices look like at this point
# reshape input to format: [batch, in_height, in_width, in_channels]
image_rgb = tf.reshape(i_rgb, [1, 4, 4, 3])
# reshape kernel to format: [filter_height, filter_width, in_channels, out_channels]
kernel_rgb = tf.reshape(k_rgb, [3, 3, 3, 1])
conv_rgb = tf.squeeze( tf.nn.conv2d(image_rgb, kernel_rgb, [1,1,1,1], "VALID") )
with tf.Session() as sess:
    conv_result = sess.run(conv_rgb)
    print(conv_result)
I get the final output:
[[35. 15.]
[35. 26.]]
But I was expecting the original output*3:
[[42. 18.]
[18. 36.]]
because from my understanding, each channel of the kernel is convolved with each channel of the input, and the resultant matrices are summed to get the final output.
Am I missing something from this process or the tensorflow implementation?
Reshape is a tricky function. It will give you the shape you want, but it can easily jumble elements together. In cases like yours, reshape should be avoided.
In this particular case it is better to duplicate the arrays along the new axis. With the [batch, in_height, in_width, in_channels] format, channels is the last dimension, and that is the axis to pass to repeat(). The following code better reflects the logic behind it:
i_grey = np.expand_dims(i_grey, axis=0) # add batch dim
i_grey = np.expand_dims(i_grey, axis=3) # add channel dim
i_rgb = np.repeat(i_grey, 3, axis=3 ) # duplicate along channels dim
And likewise with filters:
k = np.expand_dims(k, axis=2) # input channels dim
k = np.expand_dims(k, axis=3) # output channels dim
k_rgb = np.repeat(k, 3, axis=2) # duplicate along the input channels dim
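To check this without a TensorFlow session, here is a rough NumPy sketch. conv2d_valid is a hypothetical helper that mimics a stride-1 "VALID" tf.nn.conv2d for a single image and a single output channel (no kernel flipping); it is not TensorFlow's implementation.

```python
import numpy as np


def conv2d_valid(image, kernel):
    """Naive stride-1 'VALID' cross-correlation.

    image:  [in_height, in_width, in_channels]
    kernel: [filter_height, filter_width, in_channels]
    """
    fh, fw, _ = kernel.shape
    out_h = image.shape[0] - fh + 1
    out_w = image.shape[1] - fw + 1
    out = np.zeros((out_h, out_w))
    for r in range(out_h):
        for c in range(out_w):
            # Multiply the window by the kernel and sum over all channels.
            out[r, c] = np.sum(image[r:r + fh, c:c + fw, :] * kernel)
    return out


i_grey = np.array([[4, 3, 1, 0],
                   [2, 1, 0, 1],
                   [1, 2, 4, 1],
                   [3, 1, 0, 2]])
k = np.array([[1, 0, 1],
              [2, 1, 0],
              [0, 0, 1]])

# Duplicate along the channel (last) axis, as suggested above.
good = conv2d_valid(np.repeat(i_grey[:, :, None], 3, axis=2),
                    np.repeat(k[:, :, None], 3, axis=2))

# Duplicate along axis 0 and reshape, as in the question.
bad = conv2d_valid(np.repeat(i_grey[None], 3, axis=0).reshape(4, 4, 3),
                   np.repeat(k[None], 3, axis=0).reshape(3, 3, 3))

print(good)  # [[42. 18.] [18. 36.]]
print(bad)   # differs, because reshape interleaved pixels into the channel axis
```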
I'm trying to build a neural network where the labels and the number of labels change with the input. For example, I could have a final layer of 10 units that represent the logits of their classes, but sometimes I will only need units [1,3,4] to calculate cross entropy, sometimes units [3,4,5,7], etc.
I tried different combinations of map_fn, gather, py_fn and while_loop, but none of them seems to fit my case. Another way might be to list all types of label combinations (I call them network heads) and find some conditional construct that lets me choose one based on the value of a placeholder. But I'm not sure how to implement it.
For example:
x = tf.placeholder(dtype=tf.float32, shape=[None,3])
y = tf.placeholder(dtype=tf.int32, shape=[None, 3])
... to_do ...
with tf.Session() as sess:
    sess.run(to_do, feed_dict={x: [[1, 3, 4], [3, 7, 8]], y: [[1, 0, 0], [0, 1, 1]]})
Here I need something that returns [[1],[7,8]].
Oh, never mind. There was a very easy way to get the probabilities I needed for cross-entropy.
x = tf.placeholder(dtype=tf.float32, shape=[None,3])
y = tf.placeholder(dtype=tf.int32, shape=[None, 3])
probabilities = tf.where(tf.equal(y,1), tf.exp(x), tf.zeros_like(x))
normalizing_sum = tf.reduce_sum(probabilities, 1, keep_dims=True)
probabilities/=normalizing_sum
with tf.Session() as sess:
    res = sess.run(probabilities, feed_dict={x: [[1, 3, 4], [3, 7, 8]], y: [[1, 0, 0], [0, 1, 1]]})
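For reference, the same masked softmax can be sketched in NumPy. The function name masked_softmax is mine, and the max subtraction is an addition for numerical stability that is not in the code above; the sketch assumes every row selects at least one unit.

```python
import numpy as np


def masked_softmax(logits, mask):
    """Softmax over the entries where mask == 1; other entries get probability 0."""
    logits = np.asarray(logits, dtype=np.float64)
    selected = np.asarray(mask).astype(bool)
    # Unselected logits become -inf, so exp() maps them to exactly 0.
    shifted = np.where(selected, logits, -np.inf)
    shifted = shifted - shifted.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


x = [[1, 3, 4], [3, 7, 8]]
y = [[1, 0, 0], [0, 1, 1]]
probabilities = masked_softmax(x, y)
print(probabilities.round(4))
# [[1.     0.     0.    ]
#  [0.     0.2689 0.7311]]
```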
I have an input to tensorflow of shape [None, 9, 2] (where the None is batch).
To perform further actions (e.g. matmul) on it I need to transform it to [None, 18] shape. How to do it?
You can do it easily with tf.reshape() without knowing the batch size.
x = tf.placeholder(tf.float32, shape=[None, 9,2])
shape = x.get_shape().as_list() # a list: [None, 9, 2]
dim = numpy.prod(shape[1:]) # dim = prod(9,2) = 18
x2 = tf.reshape(x, [-1, dim]) # -1 means "all"
The -1 in the last line tells tf.reshape() to infer that dimension from the total number of elements, so it works whatever the batch size is at runtime. See the documentation of tf.reshape().
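NumPy's reshape follows the same -1 convention, so the idea can be tried quickly without building a graph. A small sketch, where the batch size of 5 is arbitrary:

```python
import numpy as np

batch = np.arange(5 * 9 * 2, dtype=np.float32).reshape(5, 9, 2)

dim = int(np.prod(batch.shape[1:]))  # 9 * 2 = 18
flat = batch.reshape(-1, dim)        # -1 is inferred as the batch size

print(flat.shape)  # (5, 18)
# Each row is one flattened example; elements keep their original order.
print(np.array_equal(flat[0], batch[0].ravel()))  # True
```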
Update: shape = [None, 3, None]
Thanks @kbrose. For cases where more than one dimension is undefined, we can use tf.shape() with tf.reduce_prod() instead.
x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])
tf.shape() returns a shape Tensor that is evaluated at runtime. The difference between x.get_shape() and tf.shape() is described in the docs.
I also tried tf.contrib.layers.flatten() in another . It is simplest for the first case, but it can't handle the second.
flat_inputs = tf.layers.flatten(inputs)
You can use dynamic reshaping: get the value of the batch dimension through tf.shape at runtime, compute the full set of new dimensions, and pass them to tf.reshape. Here's an example of reshaping a flat list into a square matrix without knowing the list length.
tf.reset_default_graph()
sess = tf.InteractiveSession("")
a = tf.placeholder(dtype=tf.int32)
# get [9]
ashape = tf.shape(a)
# slice the list from 0th to 1st position
ashape0 = tf.slice(ashape, [0], [1])
# reshape list to scalar, ie from [9] to 9
ashape0_flat = tf.reshape(ashape0, ())
# tf.sqrt doesn't support int, so cast to float
ashape0_flat_float = tf.to_float(ashape0_flat)
newshape0 = tf.sqrt(ashape0_flat_float)
# convert [3, 3] Python list into [3, 3] Tensor
# (tf.pack was renamed to tf.stack in TensorFlow 1.0)
newshape = tf.stack([newshape0, newshape0])
# tf.reshape doesn't accept float, so convert back to int
newshape_int = tf.to_int32(newshape)
a_reshaped = tf.reshape(a, newshape_int)
sess.run(a_reshaped, feed_dict={a: np.ones((9))})
You should see
array([[1, 1, 1],
[1, 1, 1],
[1, 1, 1]], dtype=int32)