Perform a coordinate transformation of a 4th-order tensor with np.einsum and np.tensordot - numpy

The equation is
$C'_{ijkl} = Q_{im} Q_{jn} C_{mnop} (Q^{-1})_{ok} (Q^{-1})_{pl}$
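Because Q here is a rotation, it is orthogonal, so $(Q^{-1})_{ok} = Q_{ko}$. The equation is then the usual rule for rotating a 4th-order tensor, with Q applied to every index:

$$
(Q^{-1})_{ok} = (Q^{T})_{ok} = Q_{ko}
\quad\Longrightarrow\quad
C'_{ijkl} = Q_{im}\, Q_{jn}\, Q_{ko}\, Q_{lp}\, C_{mnop}
$$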
I was able to use
np.einsum('im,jn,mnop,ok,pl', Q, Q, C, Q_inv, Q_inv)
to do the job, and I also expected
np.tensordot(np.tensordot(np.tensordot(Q, np.tensordot(Q, C, axes=[1,1]), axes=[1,0]), Q_inv, axes=[2,0]), Q_inv, axes=[3,0])
to work, but it doesn't.
Specifics:
C is a 4th-order elastic tensor:
array([[[[ 552.62389047, -0.28689554, -0.32194701],
[ -0.28689554, 118.89168597, -0.65559912],
[ -0.32194701, -0.65559912, 130.21758722]],
[[ -0.28689554, 166.02923119, -0.00000123],
[ 166.02923119, 0.49494431, -0.00000127],
[ -0.00000123, -0.00000127, -0.57156702]],
[[ -0.32194701, -0.00000123, 165.99413061],
[ -0.00000123, -0.64666809, -0.0000013 ],
[ 165.99413061, -0.0000013 , 0.42997465]]],
[[[ -0.28689554, 166.02923119, -0.00000123],
[ 166.02923119, 0.49494431, -0.00000127],
[ -0.00000123, -0.00000127, -0.57156702]],
[[ 118.89168597, 0.49494431, -0.64666809],
[ 0.49494431, 516.15898907, -0.33132485],
[ -0.64666809, -0.33132485, 140.09010389]],
[[ -0.65559912, -0.00000127, -0.0000013 ],
[ -0.00000127, -0.33132485, 165.98553869],
[ -0.0000013 , 165.98553869, 0.41913346]]],
[[[ -0.32194701, -0.00000123, 165.99413061],
[ -0.00000123, -0.64666809, -0.0000013 ],
[ 165.99413061, -0.0000013 , 0.42997465]],
[[ -0.65559912, -0.00000127, -0.0000013 ],
[ -0.00000127, -0.33132485, 165.98553869],
[ -0.0000013 , 165.98553869, 0.41913346]],
[[ 130.21758722, -0.57156702, 0.42997465],
[ -0.57156702, 140.09010389, 0.41913346],
[ 0.42997465, 0.41913346, 486.62412063]]]])
Q is a 90° rotation about the z axis, which swaps the x and y coordinates (with a sign change on one of them).
array([[ 0, 1, 0],
[-1, 0, 0],
[ 0, 0, 1]])
Q_inv is
array([[-0., -1., -0.],
[ 1., 0., 0.],
[ 0., 0., 1.]])
np.einsum gives
array([[[[ 516.15898907, -0.49494431, -0.33132485],
[ -0.49494431, 118.89168597, 0.64666809],
[ -0.33132485, 0.64666809, 140.09010389]],
[[ -0.49494431, 166.02923119, 0.00000127],
[ 166.02923119, 0.28689554, -0.00000123],
[ 0.00000127, -0.00000123, 0.57156702]],
[[ -0.33132485, 0.00000127, 165.98553869],
[ 0.00000127, -0.65559912, 0.0000013 ],
[ 165.98553869, 0.0000013 , 0.41913346]]],
[[[ -0.49494431, 166.02923119, 0.00000127],
[ 166.02923119, 0.28689554, -0.00000123],
[ 0.00000127, -0.00000123, 0.57156702]],
[[ 118.89168597, 0.28689554, -0.65559912],
[ 0.28689554, 552.62389047, 0.32194701],
[ -0.65559912, 0.32194701, 130.21758722]],
[[ 0.64666809, -0.00000123, 0.0000013 ],
[ -0.00000123, 0.32194701, 165.99413061],
[ 0.0000013 , 165.99413061, -0.42997465]]],
[[[ -0.33132485, 0.00000127, 165.98553869],
[ 0.00000127, -0.65559912, 0.0000013 ],
[ 165.98553869, 0.0000013 , 0.41913346]],
[[ 0.64666809, -0.00000123, 0.0000013 ],
[ -0.00000123, 0.32194701, 165.99413061],
[ 0.0000013 , 165.99413061, -0.42997465]],
[[ 140.09010389, 0.57156702, 0.41913346],
[ 0.57156702, 130.21758722, -0.42997465],
[ 0.41913346, -0.42997465, 486.62412063]]]])
which I believe is correct, while the four nested np.tensordot calls give
array([[[[ 552.62389047, -0.28689554, 0.32194701],
[ -0.28689554, 118.89168597, 0.65559912],
[ -0.32194701, -0.65559912, -130.21758722]],
[[ -0.28689554, 166.02923119, 0.00000123],
[ 166.02923119, 0.49494431, 0.00000127],
[ -0.00000123, -0.00000127, 0.57156702]],
[[ -0.32194701, -0.00000123, -165.99413061],
[ -0.00000123, -0.64666809, 0.0000013 ],
[ 165.99413061, -0.0000013 , -0.42997465]]],
[[[ -0.28689554, 166.02923119, 0.00000123],
[ 166.02923119, 0.49494431, 0.00000127],
[ -0.00000123, -0.00000127, 0.57156702]],
[[ 118.89168597, 0.49494431, 0.64666809],
[ 0.49494431, 516.15898907, 0.33132485],
[ -0.64666809, -0.33132485, -140.09010389]],
[[ -0.65559912, -0.00000127, 0.0000013 ],
[ -0.00000127, -0.33132485, -165.98553869],
[ -0.0000013 , 165.98553869, -0.41913346]]],
[[[ 0.32194701, 0.00000123, 165.99413061],
[ 0.00000123, 0.64666809, -0.0000013 ],
[-165.99413061, 0.0000013 , 0.42997465]],
[[ 0.65559912, 0.00000127, -0.0000013 ],
[ 0.00000127, 0.33132485, 165.98553869],
[ 0.0000013 , -165.98553869, 0.41913346]],
[[-130.21758722, 0.57156702, 0.42997465],
[ 0.57156702, -140.09010389, 0.41913346],
[ -0.42997465, -0.41913346, 486.62412063]]]])
Notice the large negative values (for example -130.21758722 and -165.99413061), which do not appear in the einsum result.

Approach #1
One way is to use np.tensordot to get the same result as np.einsum, though not in a single step and with some help from broadcasting:
import numpy as np
# Build the outer product of Q with itself by broadcasting.
# This is equivalent to np.einsum('im,jn->ijmn', Q, Q).
Q_ext = Q[:,None,:,None]*Q[:,None,:]
# Similarly for Q_inv: equivalent to np.einsum('ok,pl->opkl', Q_inv, Q_inv).
Q_inv_ext = Q_inv[:,None,:,None]*Q_inv[:,None,:]
# Perform np.einsum('im,jn,mnop->ijop', Q, Q, C) with np.tensordot.
# We contract the last two axes of Q_ext (axes=[2,3], i.e. 'mn')
# with the first two axes of C (axes=[0,1], also 'mn').
# Those axes are summed away by the dot product, leaving 'ijop'.
parte1 = np.tensordot(Q_ext,C,axes=([2,3],[0,1]))
# Do the same once more, i.e. np.einsum('ijop,ok,pl->ijkl', parte1, Q_inv, Q_inv),
# to sum away 'o' and 'p', leaving 'ijkl'.
# To confirm, compare the result against the original einsum approach:
# np.einsum('im,jn,mnop,ok,pl->ijkl', Q, Q, C, Q_inv, Q_inv)
out = np.tensordot(parte1,Q_inv_ext,axes=([2,3],[0,1]))
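A quick way to check Approach #1 is to compare each piece with its einsum equivalent. A sketch using a random stand-in for C (an assumption; the real elastic tensor behaves the same way) and Q_inv = Q.T, which holds because Q is orthogonal:

```python
import numpy as np

rng = np.random.default_rng(0)
C = rng.standard_normal((3, 3, 3, 3))   # random stand-in for the elastic tensor
Q = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], dtype=float)
Q_inv = Q.T                             # valid because Q is orthogonal

# Outer products built by broadcasting, as in Approach #1.
Q_ext = Q[:, None, :, None] * Q[:, None, :]              # 'ijmn'
Q_inv_ext = Q_inv[:, None, :, None] * Q_inv[:, None, :]  # 'opkl'

# The broadcasted products are exactly the einsum outer products.
assert np.allclose(Q_ext, np.einsum('im,jn->ijmn', Q, Q))
assert np.allclose(Q_inv_ext, np.einsum('ok,pl->opkl', Q_inv, Q_inv))

parte1 = np.tensordot(Q_ext, C, axes=([2, 3], [0, 1]))        # 'ijop'
out = np.tensordot(parte1, Q_inv_ext, axes=([2, 3], [0, 1]))  # 'ijkl'

expected = np.einsum('im,jn,mnop,ok,pl->ijkl', Q, Q, C, Q_inv, Q_inv)
print(np.allclose(out, expected))  # True
```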
Approach #2
If you would rather avoid broadcasting and use two more np.tensordot calls instead, you could do:
# Perform np.einsum('jn,mnop->jmop', Q, C). Q is labelled 'jn' and C is 'mnop'.
# We need to sum over the 'n' dimension, i.e. axis 1 of Q and axis 1 of C.
# Thus 'jn' + 'mnop' => 'jmop' once 'n' is summed away; the remaining axes keep their order.
Q_C1 = np.tensordot(Q,C,axes=([1],[1]))
# Perform np.einsum('im,jn,mnop->ijop', Q, Q, C) using Q and Q_C1.
# Q is 'im' and Q_C1 is 'jmop', so this time we sum over 'm':
# axis 1 of Q and axis 1 of Q_C1.
# Thus 'im' + 'jmop' => 'ijop' once 'm' is summed away.
parte1 = np.tensordot(Q,Q_C1,axes=([1],[1]))
# Apply the same idea with Q_inv to sum over 'o' and then 'p':
# 'ijop' + 'ok' => 'ijpk', then 'ijpk' + 'pl' => 'ijkl'.
out = np.tensordot(np.tensordot(parte1,Q_inv,axes=([2],[0])),Q_inv,axes=([2],[0]))
The trick with np.tensordot is to keep track of which dimensions the axes parameter sums away, and of how the remaining dimensions are ordered in the output: first the remaining axes of the first input, in their original order, then those of the second.
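np.tensordot returns the remaining axes of its first input followed by those of its second. Tracking axis labels that way through the chain in the question shows where it goes wrong: the second call contracts axis 0 of the intermediate result (the j axis) instead of axis 1 (m), and the last call contracts axis 3 (the k axis just produced) instead of axis 2 (p). A sketch that checks both chains against einsum, using a random stand-in for C (an assumption; any 3x3x3x3 array works) and Q_inv = Q.T:

```python
import numpy as np

rng = np.random.default_rng(1)
C = rng.standard_normal((3, 3, 3, 3))   # random stand-in for the elastic tensor
Q = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], dtype=float)
Q_inv = Q.T                             # valid because Q is orthogonal

expected = np.einsum('im,jn,mnop,ok,pl->ijkl', Q, Q, C, Q_inv, Q_inv)

# The chain from the question, with the axis labels of each intermediate.
s1 = np.tensordot(Q, C, axes=[1, 1])             # 'jn' . 'mnop' -> 'jmop'
s2 = np.tensordot(Q, s1, axes=[1, 0])            # pairs Q with the j axis, not m (wrong)
s3 = np.tensordot(s2, Q_inv, axes=[2, 0])        # sums over o, appends k
question = np.tensordot(s3, Q_inv, axes=[3, 0])  # sums over k, not p (wrong)

# Corrected chain: same calls, but with the right axes.
t2 = np.tensordot(Q, s1, axes=[1, 1])            # 'im' . 'jmop' -> 'ijop'
t3 = np.tensordot(t2, Q_inv, axes=[2, 0])        # 'ijop' . 'ok' -> 'ijpk'
fixed = np.tensordot(t3, Q_inv, axes=[2, 0])     # 'ijpk' . 'pl' -> 'ijkl'

print(np.allclose(question, expected))  # False
print(np.allclose(fixed, expected))     # True
```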


Taking fft / ifft of a stereo signal in numpy?

This question is related to both "Apply FFT to a both channels of a stereo signal separately?" and "How to represent stereo audio data for FFT", but is specifically about numpy's fft package.
How do I take the FFT of a (real-valued) stereo signal in numpy, and how do I get it back to the time domain?
If your stereo data is in two columns (i.e. left channel in column 0 and right channel in column 1), you can do it in a single operation - you only need to transpose the data first. To demonstrate:
Here are two channels of data, 14 samples long. The left is a cosine wave at f1 (it completes one cycle in the 14 samples), the right is a cosine wave at f2 (it completes two cycles):
s = array([[ 0.14285714, 0.14285714],
[ 0.12870984, 0.08906997],
[ 0.08906997, -0.0317887 ],
[ 0.0317887 , -0.12870984],
[-0.0317887 , -0.12870984],
[-0.08906997, -0.0317887 ],
[-0.12870984, 0.08906997],
[-0.14285714, 0.14285714],
[-0.12870984, 0.08906997],
[-0.08906997, -0.0317887 ],
[-0.0317887 , -0.12870984],
[ 0.0317887 , -0.12870984],
[ 0.08906997, -0.0317887 ],
[ 0.12870984, 0.08906997]])
If you transpose it (so the left channel is row 0 and the right channel is row 1), you can then pass it directly to np.fft.rfft() for the conversion:
>>> s_t = s.transpose()
>>> s_t
array([[ 0.14285714, 0.12870984, 0.08906997, 0.0317887 , -0.0317887 ,
-0.08906997, -0.12870984, -0.14285714, -0.12870984, -0.08906997,
-0.0317887 , 0.0317887 , 0.08906997, 0.12870984],
[ 0.14285714, 0.08906997, -0.0317887 , -0.12870984, -0.12870984,
-0.0317887 , 0.08906997, 0.14285714, 0.08906997, -0.0317887 ,
-0.12870984, -0.12870984, -0.0317887 , 0.08906997]])
>>> f = np.fft.rfft(s_t)
>>> np.set_printoptions(suppress=True) # make it easier to read
>>> f
array([[ 0.+0.j, 1.+0.j, 0.+0.j, -0.-0.j, 0.-0.j, -0.+0.j, 0.+0.j, 0.+0.j],
[-0.+0.j, 0.+0.j, 1.+0.j, -0.-0.j, 0.-0.j, 0.+0.j, -0.+0.j, 0.+0.j]])
>>>
You can see from above that the left channel (row 0) has a '1' in bin 1 and the right channel (row 1) has a '1' in bin 2, which is what we'd expect. If you want your frequency data to be in column format, of course you can transpose that. And if you want just the real components, you can do that at the same time:
>>> f.transpose().real
array([[ 0., -0.],
[ 1., 0.],
[ 0., 1.],
[-0., -0.],
[ 0., 0.],
[-0., 0.],
[ 0., -0.],
[ 0., 0.]])
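To translate bin indices into frequencies, np.fft.rfftfreq returns the frequency of each rfft bin. The sketch below rebuilds the example signal (the printed values match (1/7)*cos(2*pi*k*n/14) with k = 1 for the left channel and k = 2 for the right) and assumes a hypothetical sample rate of 14 Hz, which the question does not give, so that bin k sits at k Hz:

```python
import numpy as np

n = 14        # samples per channel in the example signal
fs = 14.0     # hypothetical sample rate in Hz (not given in the question)

# Rebuild the example stereo signal: one cycle (left) and two cycles (right).
t = np.arange(n)
s = np.column_stack([np.cos(2 * np.pi * 1 * t / n),
                     np.cos(2 * np.pi * 2 * t / n)]) / 7

f = np.fft.rfft(s.T)                   # one row of bins per channel
freqs = np.fft.rfftfreq(n, d=1 / fs)   # frequency (Hz) of each bin

# Frequency of the strongest bin in each channel.
peaks = freqs[np.argmax(np.abs(f), axis=1)]
print(freqs)   # [0. 1. 2. 3. 4. 5. 6. 7.]
print(peaks)   # [1. 2.]
```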
To confirm that this is a proper transform of the original stereo data, compare the inverse transform with s (above):
>>> np.fft.irfft(f).transpose().real
array([[ 0.14285714, 0.14285714],
[ 0.12870984, 0.08906997],
[ 0.08906997, -0.0317887 ],
[ 0.0317887 , -0.12870984],
[-0.0317887 , -0.12870984],
[-0.08906997, -0.0317887 ],
[-0.12870984, 0.08906997],
[-0.14285714, 0.14285714],
[-0.12870984, 0.08906997],
[-0.08906997, -0.0317887 ],
[-0.0317887 , -0.12870984],
[ 0.0317887 , -0.12870984],
[ 0.08906997, -0.0317887 ],
[ 0.12870984, 0.08906997]])
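Transposing is optional: np.fft.rfft and np.fft.irfft both take an axis argument, so the one-column-per-channel layout can be transformed directly. Passing n= to irfft also makes sure the original length comes back; without it, irfft assumes an even length of 2*(m - 1) for m bins, which drops a sample when the signal length is odd. A sketch with the same 14-sample signal:

```python
import numpy as np

n = 14
t = np.arange(n)
# Columns are channels: left = 1 cycle, right = 2 cycles (matches the printed s).
s = np.column_stack([np.cos(2 * np.pi * t / n),
                     np.cos(2 * np.pi * 2 * t / n)]) / 7

f = np.fft.rfft(s, axis=0)             # shape (8, 2): bins x channels
s_back = np.fft.irfft(f, n=n, axis=0)  # back to shape (14, 2), real-valued

print(f.shape, s_back.shape)           # (8, 2) (14, 2)
print(np.allclose(s_back, s))          # True
```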

TypeError: '>=' not supported between instances of 'dict' and 'dict'

I get this error from np.max(). My code is:
err_qt = np.asarray(err_qt)
err_qt = np.max(err_qt, axis=1)
err_qt = err_qt * 180.0 / np.pi
err_qt is a matrix; here is what printing it shows:
[[{'K': array([[682.10083008, 0. , 239.82280391],
[ 0. , 683.28820801, 427.01684648],
[ 0. , 0. , 1. ]]), 'R': array([[-9.93854702e-01, -1.10692268e-01, -2.28596966e-04],
[ 1.16574434e-03, -8.40159629e-03, -9.99964026e-01],
[ 1.10686366e-01, -9.93819216e-01, 8.47900486e-03]]), 'T': array([-15.13970906, 1.19332025, 16.90097562]), 'q': array([0., 0., 0., 0.])}
{'K': array([[682.10083008, 0. , 239.82280391],
[ 0. , 683.28820801, 427.01684648],
[ 0. , 0. , 1. ]]), 'R': array([[-0.88778025, -0.45392103, -0.0761704 ],
[-0.04630574, 0.2527365 , -0.96642643],
[ 0.45793232, -0.85444716, -0.24539362]]), 'T': array([-5.09243763, -4.6725806 , 24.14722353]), 'q': array([0., 0., 0., 0.])}]
[{'K': array([[682.10083008, 0. , 239.82280391],
[ 0. , 683.28820801, 427.01684648],
[ 0. , 0. , 1. ]]), 'R': array([[-9.93854702e-01, -1.10692268e-01, -2.28596966e-04],
[ 1.16574434e-03, -8.40159629e-03, -9.99964026e-01],
[ 1.10686366e-01, -9.93819216e-01, 8.47900486e-03]]), 'T': array([-15.13970906, 1.19332025, 16.90097562]), 'q': array([0., 0., 0., 0.])}
...
How can I solve it?
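The error comes from the array's contents rather than from np.max itself: np.asarray turns a nested list of dicts into an object-dtype array, and reducing an object array with np.max compares elements with >=, which dicts don't support. A minimal sketch that reproduces this and then reduces over a numeric value instead. Which field (or computed error) you actually want is an assumption here, so the key "q" and the data are only placeholders:

```python
import numpy as np

# Hypothetical data shaped like the printed err_qt: rows of dicts.
err_qt = np.asarray([
    [{"q": np.array([0.1, 0.2])}, {"q": np.array([0.3, 0.0])}],
    [{"q": np.array([0.5, 0.4])}, {"q": np.array([0.2, 0.1])}],
])
print(err_qt.dtype, err_qt.shape)  # object (2, 2)

raised = False
try:
    np.max(err_qt, axis=1)         # has to compare dict >= dict
except TypeError as exc:
    raised = True
    print("TypeError:", exc)

# Pull a numeric value out of each entry first, then reduce.
values = np.array([[np.max(d["q"]) for d in row] for row in err_qt])
row_max = np.max(values, axis=1)   # float array, shape (2,)
print(row_max * 180.0 / np.pi)
```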

PyTorch's gather, squeeze and unsqueeze to TensorFlow Keras

I am migrating code from PyTorch to TensorFlow. In the function that calculates the loss, I have the line below, which I need to port to TensorFlow:
state_action_values = net(t_states_features).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
I found tf.gather and tf.gather_nd, but I am not sure which one is more suitable or how to use it. Also, is tf.expand_dims the equivalent of unsqueeze?
To get a clearer view of what the line does, I split it into several parts with print statements:
print("net result")
state_action_values = net(t_states_features)
print(state_action_values)
print("gather result")
state_action_values = state_action_values.gather(1, actions_v.unsqueeze(-1))
print(state_action_values)
print("last squeeze")
state_action_values = state_action_values.squeeze(-1)
print(state_action_values)
net result
tensor([[ 45.6878, -14.9495, 59.3737],
[ 33.5737, -10.4617, 39.0078],
[ 67.7197, -22.8818, 85.7977],
[ 94.7701, -33.2053, 120.5519],
[ nan, nan, nan],
[ 84.7324, -29.2101, 108.0821],
[ 67.7193, -22.7702, 86.9558],
[113.6835, -38.7149, 142.6167],
[ 61.9260, -20.1968, 79.8010],
[ 51.6152, -17.7391, 66.0719],
[ 73.6565, -21.5699, 98.9463],
[ 84.0761, -26.5016, 107.6888],
[ 60.9459, -20.1257, 76.4105],
[103.2883, -35.4035, 130.4503],
[ 37.1156, -13.5180, 47.1067],
[ nan, nan, nan],
[ 55.6286, -18.5239, 71.9837],
[ 55.3858, -18.7892, 71.1197],
[ 50.2419, -17.2959, 66.7059],
[ 82.5715, -30.0302, 108.4984],
[ -0.8662, -1.1861, 1.6033],
[112.4620, -38.6416, 142.4556],
[ 57.8702, -19.8080, 74.7656],
[ 45.8418, -15.7436, 57.3367],
[ 81.6596, -27.5002, 104.6002],
[ 57.1507, -21.8001, 67.7933],
[ 35.0414, -11.8199, 47.6573],
[ 67.7085, -23.1017, 85.4623],
[ 40.6284, -12.4578, 58.9603],
[ 68.6394, -23.1481, 87.0832],
[ 27.0549, -8.6635, 34.0150],
[ 25.4071, -8.5511, 34.0285],
[ 62.9161, -22.1693, 78.7965],
[ 85.4505, -28.1487, 108.6252],
[ 67.6665, -23.2376, 85.7117],
[ 60.7806, -20.2784, 77.1022],
[ 66.5209, -21.5674, 88.5561],
[ 61.6637, -20.9891, 72.3873],
[ 45.1634, -15.4678, 61.4886],
[ 66.8119, -23.1250, 85.6189],
[ nan, nan, nan],
[ 67.8166, -24.8342, 84.6706],
[ 86.2114, -29.5941, 107.8025],
[ 66.2716, -23.3309, 83.9700],
[101.2122, -35.3554, 127.4772],
[ 61.0749, -19.4720, 78.5588],
[ 50.4058, -16.1262, 63.1010],
[ 27.7543, -9.3767, 35.7448],
[ 67.7810, -23.4962, 83.6030],
[ 35.0103, -11.7238, 44.7983],
[ 55.7402, -19.0223, 70.3627],
[ 67.9733, -22.0783, 85.1893],
[ 60.5253, -20.3157, 79.7312],
[ 67.2404, -21.5205, 81.4499],
[ 57.9502, -20.7747, 70.9109],
[ 87.6536, -31.4256, 112.6491],
[ 90.3668, -30.7755, 116.6192],
[ 59.0660, -19.6988, 75.0723],
[ 50.0969, -17.4135, 62.6556],
[ 28.8703, -9.0950, 34.5749],
[ 68.4053, -22.0715, 88.2302],
[ 69.1397, -21.4236, 84.7833],
[ 23.8506, -8.1834, 30.8318],
[ 58.4296, -20.2432, 73.8116],
[ 87.5317, -29.0606, 110.0389],
[ nan, nan, nan],
[ 88.6387, -30.6154, 112.4239],
[ 51.6089, -16.1073, 66.2757],
[ 94.3989, -32.1473, 119.0358],
[ 82.7449, -30.7778, 102.8537],
[ 74.3067, -26.6585, 98.2536],
[ 77.0881, -26.5706, 98.3553],
[ 28.5688, -9.2949, 41.1165],
[ 86.1560, -26.9364, 107.0244],
[ 41.8914, -16.9703, 57.3840],
[ 88.8886, -29.7008, 108.2697],
[ 61.1243, -20.7566, 77.2257],
[ 85.1174, -28.7558, 107.3853],
[ 81.7256, -27.9047, 104.5006],
[ 51.2663, -16.5880, 67.1428],
[ 46.9150, -12.7457, 61.3240],
[ 36.1758, -12.9769, 47.7178],
[ 85.5846, -29.4141, 107.9649],
[ 59.9424, -20.8349, 75.3359],
[ 62.6516, -22.1235, 81.6903],
[104.7664, -34.5876, 129.9478],
[ 64.4671, -23.3980, 83.9093],
[ 69.6928, -23.6567, 89.6024],
[ 60.4407, -19.6136, 75.9350],
[ 33.4921, -10.3434, 44.9537],
[ 57.9112, -19.4174, 74.3050],
[ 24.8262, -9.3637, 30.1057],
[ 85.3776, -28.9097, 110.1310],
[ 63.8175, -22.3843, 81.0308],
[ 34.6040, -12.3217, 46.0356],
[ 88.3740, -29.5049, 110.2897],
[ 66.8196, -22.5860, 85.5386],
[ 58.9767, -22.0601, 78.7086],
[ 83.2090, -26.3499, 113.5105],
[ 54.8450, -17.7980, 68.1161],
[ nan, nan, nan],
[ 85.0846, -29.2494, 107.6780],
[ 76.9251, -26.2295, 98.4755],
[ 98.2907, -32.8878, 124.9192],
[ 91.1387, -30.8262, 115.3978],
[ 73.1062, -24.9450, 90.0967],
[ 27.6564, -8.6114, 35.4470],
[ 71.8508, -25.1529, 95.5165],
[ 69.7275, -20.1357, 86.9620],
[ 67.0907, -21.9245, 84.8853],
[ 77.3163, -25.5980, 92.7700],
[ 63.0082, -21.0345, 78.7311],
[ 68.0553, -22.4280, 84.8031],
[ 5.8148, -2.3171, 8.0620],
[103.3399, -35.1769, 130.7801],
[ 54.8769, -18.6822, 70.4657],
[ 58.4446, -18.9764, 75.5509],
[ 91.0071, -31.2706, 112.6401],
[ 84.6577, -29.2644, 104.6046],
[ 45.4887, -15.8309, 59.0498],
[ 56.3384, -18.9264, 78.8834],
[ 63.5109, -21.3169, 81.5144],
[ 79.4635, -29.8681, 100.5056],
[ 27.6559, -10.0517, 35.6012],
[ 76.3909, -24.1689, 93.6133],
[ 34.3802, -11.5272, 45.8650],
[ 60.3553, -20.1693, 76.5371],
[ 56.0590, -18.6468, 69.8981]], grad_fn=<AddmmBackward0>)
gather result
tensor([[ 59.3737],
[-10.4617],
[ 67.7197],
[ 94.7701],
[ nan],
[-29.2101],
[ 67.7193],
[-38.7149],
[-20.1968],
[ 66.0719],
[ 98.9463],
[107.6888],
[-20.1257],
[-35.4035],
[ 47.1067],
[ nan],
[ 55.6286],
[-18.7892],
[ 66.7059],
[-30.0302],
[ 1.6033],
[112.4620],
[ 74.7656],
[-15.7436],
[ 81.6596],
[-21.8001],
[ 35.0414],
[-23.1017],
[ 40.6284],
[ 68.6394],
[ 34.0150],
[ 34.0285],
[ 78.7965],
[-28.1487],
[ 67.6665],
[-20.2784],
[-21.5674],
[ 72.3873],
[-15.4678],
[ 85.6189],
[ nan],
[-24.8342],
[-29.5941],
[-23.3309],
[101.2122],
[-19.4720],
[-16.1262],
[ -9.3767],
[-23.4962],
[-11.7238],
[ 70.3627],
[-22.0783],
[-20.3157],
[ 67.2404],
[-20.7747],
[112.6491],
[-30.7755],
[-19.6988],
[ 50.0969],
[ 34.5749],
[ 88.2302],
[-21.4236],
[ -8.1834],
[ 73.8116],
[110.0389],
[ nan],
[112.4239],
[-16.1073],
[-32.1473],
[-30.7778],
[ 98.2536],
[ 98.3553],
[ 28.5688],
[107.0244],
[-16.9703],
[-29.7008],
[ 77.2257],
[-28.7558],
[-27.9047],
[ 67.1428],
[-12.7457],
[ 47.7178],
[-29.4141],
[ 59.9424],
[-22.1235],
[129.9478],
[-23.3980],
[-23.6567],
[ 75.9350],
[-10.3434],
[-19.4174],
[ 30.1057],
[ 85.3776],
[ 63.8175],
[ 46.0356],
[-29.5049],
[-22.5860],
[-22.0601],
[113.5105],
[-17.7980],
[ nan],
[-29.2494],
[ 76.9251],
[-32.8878],
[115.3978],
[-24.9450],
[ 35.4470],
[ 95.5165],
[ 86.9620],
[-21.9245],
[-25.5980],
[ 78.7311],
[-22.4280],
[ 5.8148],
[103.3399],
[ 70.4657],
[ 58.4446],
[ 91.0071],
[104.6046],
[ 45.4887],
[-18.9264],
[ 63.5109],
[ 79.4635],
[-10.0517],
[ 76.3909],
[ 34.3802],
[-20.1693],
[-18.6468]], grad_fn=<GatherBackward0>)
last squeeze
tensor([ 59.3737, -10.4617, 67.7197, 94.7701, nan, -29.2101, 67.7193,
-38.7149, -20.1968, 66.0719, 98.9463, 107.6888, -20.1257, -35.4035,
47.1067, nan, 55.6286, -18.7892, 66.7059, -30.0302, 1.6033,
112.4620, 74.7656, -15.7436, 81.6596, -21.8001, 35.0414, -23.1017,
40.6284, 68.6394, 34.0150, 34.0285, 78.7965, -28.1487, 67.6665,
-20.2784, -21.5674, 72.3873, -15.4678, 85.6189, nan, -24.8342,
-29.5941, -23.3309, 101.2122, -19.4720, -16.1262, -9.3767, -23.4962,
-11.7238, 70.3627, -22.0783, -20.3157, 67.2404, -20.7747, 112.6491,
-30.7755, -19.6988, 50.0969, 34.5749, 88.2302, -21.4236, -8.1834,
73.8116, 110.0389, nan, 112.4239, -16.1073, -32.1473, -30.7778,
98.2536, 98.3553, 28.5688, 107.0244, -16.9703, -29.7008, 77.2257,
-28.7558, -27.9047, 67.1428, -12.7457, 47.7178, -29.4141, 59.9424,
-22.1235, 129.9478, -23.3980, -23.6567, 75.9350, -10.3434, -19.4174,
30.1057, 85.3776, 63.8175, 46.0356, -29.5049, -22.5860, -22.0601,
113.5105, -17.7980, nan, -29.2494, 76.9251, -32.8878, 115.3978,
-24.9450, 35.4470, 95.5165, 86.9620, -21.9245, -25.5980, 78.7311,
-22.4280, 5.8148, 103.3399, 70.4657, 58.4446, 91.0071, 104.6046,
45.4887, -18.9264, 63.5109, 79.4635, -10.0517, 76.3909, 34.3802,
-20.1693, -18.6468], grad_fn=<SqueezeBackward1>)
Edit 1: print of actions_v
actions_v
tensor([2, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 2, 0, 2, 0, 1,
2, 0, 2, 1, 1, 0, 2, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1,
0, 2, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1, 2, 0, 0, 1, 1, 2, 0, 0, 2, 0, 0,
1, 1, 2, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 2, 0, 2, 0, 1, 1, 2, 1, 2, 2,
2, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2, 1, 1, 0, 1, 0, 1, 2,
2, 1, 0, 2, 0, 0, 2, 1])
tf.gather_nd takes indices that address individual elements: each index holds one coordinate per dimension of the input tensor (here, a [row, action] pair), and the output contains the values at those positions, which is what you want.
tf.gather (with its default batch_dims=0) instead returns whole slices along one axis; the indices can have any shape, and the output is simply those slices arranged according to the shape of the indices, which is not what you want here.
So first build indices that match the dimensions of the original matrix, pairing each row number with its action:
import tensorflow as tf
# tf.range yields int32, so cast actions_v to match before stacking into (N, 2) pairs.
indices = tf.stack([tf.range(tf.shape(state_action_values)[0]), tf.cast(actions_v, tf.int32)], axis=1)
Then call gather_nd:
state_action_values = tf.gather_nd(state_action_values,indices)
Keivan
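As a framework-independent reference for checking the port (an illustration, not part of the original answer), the NumPy sketch below reproduces what the PyTorch line selects: for each row i, the element in column actions[i]. np.take_along_axis mirrors gather + unsqueeze + squeeze, and explicit (row, action) index pairs mirror the tf.gather_nd approach. The values are rounded from the first rows printed above and the actions are made up. In TensorFlow 2, tf.gather(values, actions, axis=1, batch_dims=1) should be another equivalent option, though it is not exercised here.

```python
import numpy as np

# Small stand-in for net(t_states_features): 4 states x 3 actions.
q_values = np.array([[45.7, -14.9, 59.4],
                     [33.6, -10.5, 39.0],
                     [67.7, -22.9, 85.8],
                     [94.8, -33.2, 120.6]])
actions = np.array([2, 0, 1, 1])  # made-up action per state

# torch: q.gather(1, actions.unsqueeze(-1)).squeeze(-1)
picked = np.take_along_axis(q_values, actions[:, None], axis=1)[:, 0]

# gather_nd style: explicit (row, action) index pairs, one per row.
indices = np.stack([np.arange(len(actions)), actions], axis=1)
picked_nd = q_values[indices[:, 0], indices[:, 1]]

print(picked)                              # [ 59.4  33.6 -22.9 -33.2]
print(np.array_equal(picked, picked_nd))   # True
```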

PySpark take, collect and first return different values

I run a PySpark script that performs a map operation on an RDD. The result of the map is a new RDD containing tuples, one tuple for each partition of the original RDD (there are 10 partitions). Each tuple contains two NumPy ndarrays and the partition ID.
The problem is that take(10) and collect() return different values! Even if I call first() on the resulting RDD, the tuple is completely different from the first tuple returned by take or collect.
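One likely explanation, stated as an assumption since the map function isn't shown: RDDs are lazy, and unless an RDD is cached, every action such as take(), collect() or first() recomputes it from its lineage. If the per-partition function draws random numbers without a fixed seed (for example to create initial weights), each action then produces different arrays. The plain-NumPy sketch below mimics this with a hypothetical make_partition function. Seeding the generator from the partition ID (already part of each tuple) makes recomputation reproducible; calling rdd.cache() before running several actions also helps, although an evicted partition can still be recomputed.

```python
import numpy as np


def make_partition(partition_id, seeded=False):
    """Hypothetical stand-in for the per-partition map function.

    Returns two float32 arrays plus the partition ID, mirroring the RDD's
    tuples. The contents are random purely for illustration.
    """
    # Unseeded: fresh entropy on every call. Seeded: depends only on the ID.
    rng = np.random.default_rng(partition_id if seeded else None)
    a = rng.standard_normal((4, 3)).astype(np.float32)
    b = rng.standard_normal((4, 3)).astype(np.float32)
    return a, b, partition_id


# Each action on an uncached RDD re-runs the map, which is like calling the
# function again for the same partition.
run1, run2 = make_partition(0), make_partition(0)
print(np.array_equal(run1[0], run2[0]))        # almost surely False

seeded1, seeded2 = make_partition(0, seeded=True), make_partition(0, seeded=True)
print(np.array_equal(seeded1[0], seeded2[0]))  # True
```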
take(10) result is:
[(array([[ 0.19138815, -0.26613894, 0.0148395 , ..., -0.6887879 ,
0.01775263, 0.29900053],
[ 0.46834013, -0.41492677, -0.3986189 , ..., -0.09638319,
-0.27234066, 0.41824088],
[ 0.2235235 , -0.27003226, 0.05322047, ..., -0.3045229 ,
-0.30364496, 0.21981548],
...,
[ 0.40709212, -0.49947056, 0.36821032, ..., -0.27359277,
-0.1552616 , -0.10506155],
[ 0.42315334, -0.6249347 , -0.38093382, ..., -0.52247494,
0.12167282, 0.53337336],
[-0.12733063, -0.27234274, -0.05421005, ..., 0.3884521 ,
0.19977048, 0.11347781]], dtype=float32), array([[ 0.04737013, -0.13501787, -0.04363494, ..., -0.53189725,
-0.1747366 , 0.09700331],
[ 0.59339726, -0.622871 , -0.7482438 , ..., 0.01563211,
-0.24306081, 0.5259429 ],
[ 0.199726 , -0.29667157, 0.03788655, ..., -0.72914207,
-0.61997485, 0.4880442 ],
...,
[ 0.10908518, 0.15787347, 0.13822514, ..., -0.015803 ,
-0.05935803, 0.11415796],
[ 0.02415699, 0.21212797, -0.01504083, ..., -0.0880443 ,
0.03949256, -0.17959005],
[-0.08058456, -0.00894655, -0.10706384, ..., 0.00954069,
-0.18720922, -0.05665499]], dtype=float32), 0), (array([[ 0.53325355, -0.0129254 , 0.10924862, ..., 0.2797827 ,
0.44130138, -0.29074535],
[ 0.29487053, 0.20387554, 0.00834447, ..., 0.3034479 ,
-0.34347925, 0.48914096],
[-0.36053488, 0.2551153 , 0.23102154, ..., -0.08557958,
0.2305064 , -0.11637823],
...,
[ 0.50112355, -0.5487336 , 0.1381122 , ..., -0.17219128,
0.5784589 , -0.39060545],
[ 0.47822553, -0.21500733, -0.02590418, ..., 0.45222896,
-0.29980502, -0.4379743 ],
[-0.27317327, 0.47888 , 0.13328783, ..., 0.45453754,
-0.03382564, 0.28364402]], dtype=float32), array([[ 0.45505902, 0.20775506, 0.38936007, ..., 0.31039652,
0.0612184 , -0.34329525],
[ 0.44160137, 0.45554656, 0.24406506, ..., 0.20645882,
-0.55068386, 0.61949503],
[-0.9611948 , 0.48338398, 0.8936671 , ..., -0.07192033,
-0.04691654, -0.0583482 ],
...,
[ 0.27866694, 0.11791337, 0.00603435, ..., -0.0984261 ,
0.05514587, -0.16367936],
[ 0.24686542, -0.01605012, 0.15803055, ..., 0.14647359,
0.00465332, -0.21551773],
[ 0.01786835, 0.08164094, 0.0458132 , ..., 0.09466646,
-0.07061186, 0.04650302]], dtype=float32), 1), (array([[-0.30373037, 0.4013883 , 0.5544747 , ..., -0.24839583,
-0.434404 , -0.5419062 ],
[ 0.3205092 , -0.21219605, 0.23547144, ..., -0.4373149 ,
0.30616343, -0.45202586],
[-0.25409338, -0.14463192, 0.30881113, ..., -0.29998812,
-0.24947752, -0.18218543],
...,
[-0.33871105, 0.20521186, 0.38351473, ..., 0.33596635,
0.370852 , -0.46504658],
[-0.23807049, 0.1317612 , 0.07848132, ..., -0.32858378,
-0.541284 , -0.4595052 ],
[-0.17040105, 0.48929718, -0.39259502, ..., -0.15026243,
-0.19829535, 0.18581793]], dtype=float32), array([[-0.5295805 , 0.64539313, 0.47219488, ..., -0.4475045 ,
-0.31032106, -0.60634965],
[ 0.28920087, -0.2322041 , 0.313627 , ..., -0.71224356,
0.1539378 , -0.39970958],
[-0.39529595, -0.50494266, 0.61120296, ..., -0.94842255,
-0.7700451 , 0.03852249],
...,
[ 0.082367 , -0.09744105, 0.00574634, ..., 0.04329732,
0.08105459, -0.20032766],
[ 0.04620253, 0.18625231, -0.04047911, ..., -0.1547844 ,
-0.01560262, 0.00372486],
[ 0.0602592 , 0.078048 , 0.04372916, ..., -0.10393928,
-0.27185628, -0.05753115]], dtype=float32), 2), (array([[-0.10130739, -0.5432671 , 0.14230369, ..., -0.20037425,
-0.50981474, -0.39152429],
[-0.10439714, -0.19250502, -0.12469167, ..., 0.50656915,
0.41846293, -0.12511848],
[-0.12075678, 0.13746399, -0.10762265, ..., -0.33095708,
0.38831544, -0.3573719 ],
...,
[-0.33046186, 0.1675668 , -0.22495636, ..., 0.39853546,
-0.21838626, -0.44713587],
[ 0.1753454 , -0.16229314, 0.24015644, ..., -0.09867255,
-0.6676028 , -0.03644819],
[ 0.33300576, -0.28793842, 0.45033735, ..., 0.26602328,
0.18528135, -0.37736982]], dtype=float32), array([[-0.4201235 , -0.42547926, 0.3823224 , ..., -0.14358595,
-0.33159116, -0.47784004],
[-0.29869375, -0.31769726, 0.00777127, ..., 0.5764582 ,
0.160321 , -0.14145109],
[ 0.08045941, 0.31604633, -0.22956751, ..., -0.551948 ,
0.8479948 , -1.065378 ],
...,
[ 0.02872542, -0.23239408, 0.10584452, ..., -0.19653268,
0.01299276, 0.01154038],
[-0.11145593, -0.04056952, 0.17250092, ..., 0.11815882,
-0.09970374, -0.14516412],
[ 0.03227303, 0.08426446, -0.04518193, ..., -0.04804583,
0.11866879, 0.07962584]], dtype=float32), 3), (array([[-0.3442908 , -0.47924808, -0.17909145, ..., 0.57778424,
-0.20964314, -0.38468337],
[-0.1493048 , 0.22245084, -0.19882523, ..., 0.38685095,
-0.5174358 , -0.4666229 ],
[-0.36909524, -0.32142702, -0.183554 , ..., 0.16354944,
-0.31103647, -0.12249314],
...,
[-0.45103338, -0.04599327, 0.3811725 , ..., -0.28528866,
-0.47804165, -0.08458076],
[ 0.2814856 , -0.07282971, -0.6203528 , ..., 0.1307708 ,
0.01159024, -0.28599057],
[-0.20529339, 0.39615756, -0.12310734, ..., 0.21978702,
0.09362441, 0.282171 ]], dtype=float32), array([[-3.9894256e-01, -2.3258871e-01, -3.5374862e-01, ...,
5.5857885e-01, 1.3151944e-02, -3.4470174e-01],
[ 3.7879427e-03, 2.3241083e-01, -1.8348652e-01, ...,
2.8142187e-01, -5.6694925e-01, -7.0757174e-01],
[-1.0855644e+00, -1.0010500e+00, -5.0151312e-01, ...,
1.0360012e+00, -8.4597129e-01, 3.6014456e-01],
...,
[ 9.6085869e-02, 4.7818022e-03, 9.2197888e-02, ...,
5.7999889e-04, -4.3681998e-02, -1.0792557e-01],
[-2.1461031e-01, -7.9450890e-02, -3.6970940e-03, ...,
1.2008668e-01, 8.1481531e-02, -7.4978117e-03],
[ 2.2965204e-02, -7.5380646e-02, -1.5798187e-01, ...,
2.3251076e-01, 2.0147276e-01, -5.9070695e-02]], dtype=float32), 4), (array([[ 0.34773657, 0.1271629 , 0.35084245, ..., -0.4163662 ,
-0.2032509 , 0.16763268],
[ 0.25300634, -0.24963312, 0.03022493, ..., -0.48368084,
-0.5438243 , -0.2523892 ],
[ 0.08074404, 0.34778172, -0.19389851, ..., 0.00394286,
-0.03249731, 0.21850894],
...,
[-0.28799132, 0.46877176, 0.13380927, ..., 0.21133232,
-0.20794198, 0.04732013],
[ 0.533321 , 0.08838256, 0.52776694, ..., -0.3776623 ,
0.3234356 , 0.49971387],
[ 0.01111342, 0.2225688 , 0.5407952 , ..., -0.5621694 ,
-0.48505995, 0.45497927]], dtype=float32), array([[ 0.31677225, 0.3403609 , 0.43394235, ..., -0.22031932,
-0.19851643, 0.07987656],
[ 0.23734477, -0.08880972, 0.29560354, ..., -0.29227307,
-0.6075761 , 0.0549405 ],
[ 0.51728237, 0.92628396, -0.40312177, ..., 0.47099453,
0.1400124 , 0.5190122 ],
...,
[-0.01423888, -0.00644734, -0.01985722, ..., 0.00387737,
0.08140853, -0.04754239],
[-0.20655577, 0.07879162, 0.16483389, ..., 0.03559271,
-0.08346304, 0.02901335],
[ 0.12661955, -0.14012276, 0.13914892, ..., 0.09200166,
-0.01300973, 0.09209644]], dtype=float32), 5), (array([[ 0.6097177 , 0.48289144, 0.107826 , ..., 0.14263372,
0.43684664, 0.37251818],
[-0.47600058, -0.32621023, -0.55021507, ..., -0.48766413,
0.19021659, -0.25467217],
[-0.26290026, 0.19371034, -0.3084416 , ..., -0.38468444,
-0.24624074, -0.1248633 ],
...,
[ 0.29716069, -0.17764345, 0.06396675, ..., -0.51036185,
-0.16205388, -0.11298727],
[ 0.11154262, 0.17699954, -0.55426383, ..., 0.36997363,
0.10982917, -0.16837421],
[-0.0491058 , 0.13139851, 0.38043728, ..., -0.20505017,
0.15700234, 0.5629593 ]], dtype=float32), array([[ 5.4628938e-01, 4.5117161e-01, 3.2670802e-01, ...,
1.9686517e-01, 3.2789686e-01, 4.5343029e-01],
[-6.4986098e-01, -2.4272519e-01, -7.6540470e-01, ...,
-7.3495942e-01, 2.8457634e-02, -1.3791913e-01],
[-7.1882963e-01, 6.7012483e-01, -6.4978844e-01, ...,
-6.2161821e-01, -6.2733275e-01, -3.5520321e-01],
...,
[-9.8777540e-02, -3.5368185e-02, 3.7035108e-02, ...,
3.2906037e-02, -3.7386462e-02, -7.3537357e-02],
[ 7.1388625e-02, 1.7618576e-01, -3.5727687e-02, ...,
-1.8145926e-01, -8.1750348e-02, 1.6553156e-02],
[-4.8265442e-02, -1.0864433e-01, -7.2953522e-02, ...,
5.3536858e-02, 1.4538332e-04, 1.6074999e-01]], dtype=float32), 6), (array([[-0.60614747, 0.12892431, -0.3011826 , ..., 0.25248086,
-0.52810246, 0.14257856],
[-0.42474708, -0.37544054, -0.3886031 , ..., -0.35798463,
-0.18877436, 0.3048291 ],
[-0.17884058, -0.37839454, 0.3589297 , ..., 0.05497296,
-0.06037642, -0.4278129 ],
...,
[ 0.1893682 , 0.22125317, -0.20827802, ..., -0.29234493,
-0.1302274 , 0.02801383],
[ 0.1747711 , -0.0879383 , -0.395539 , ..., -0.38479805,
-0.61469847, 0.00207329],
[ 0.3819209 , 0.49023125, -0.42264247, ..., 0.1528128 ,
0.45578462, 0.4668125 ]], dtype=float32), array([[-5.28453410e-01, 1.38694681e-02, -2.60350943e-01, ...,
2.34538317e-01, -4.43873584e-01, 1.47595644e-01],
[-5.25614679e-01, -4.94065553e-01, -4.03498739e-01, ...,
-4.45649236e-01, -1.77955553e-01, 3.10622215e-01],
[ 1.01587474e-01, -8.57380509e-01, 8.03865731e-01, ...,
2.06392854e-01, -1.75320432e-01, -7.59775221e-01],
...,
[ 8.77973512e-02, -2.19333827e-01, -1.42306268e-01, ...,
-1.03031479e-01, 6.18839522e-06, 1.56334475e-01],
[-1.23811606e-02, -3.57912146e-02, -7.11395666e-02, ...,
4.88635562e-02, 1.20033674e-01, -2.04370469e-02],
[ 1.00751802e-01, -4.78806317e-01, 1.95426181e-01, ...,
-5.01195967e-01, -3.37407261e-01, 3.57391506e-01]], dtype=float32), 7), (array([[ 0.49053288, -0.04683368, -0.0879433 , ..., 0.6166893 ,
-0.5472508 , 0.5924473 ],
[-0.40197268, 0.21878959, 0.47748646, ..., -0.27519724,
0.3854015 , -0.04976773],
[ 0.24380569, -0.04194092, -0.22590604, ..., 0.35376453,
0.2546404 , -0.20142618],
...,
[-0.5142191 , -0.2877738 , -0.47166097, ..., 0.48306477,
0.26082426, -0.4445646 ],
[-0.25510445, 0.00508366, -0.5078671 , ..., -0.27604827,
-0.08479042, 0.04850767],
[-0.37874135, 0.49107817, 0.11259978, ..., 0.3188926 ,
0.15944068, 0.0829725 ]], dtype=float32), array([[ 0.5134689 , -0.18948251, -0.05751067, ..., 0.48557195,
-0.3173096 , 0.372445 ],
[-0.43809155, -0.08225048, 0.7466321 , ..., -0.5873075 ,
0.7402181 , -0.09714872],
[ 0.68586934, 0.1767921 , 0.07150997, ..., 0.8427489 ,
0.45624557, -0.32352042],
...,
[ 0.02128473, -0.00579748, 0.12406835, ..., 0.08425785,
0.21446604, 0.09736192],
[ 0.04778935, -0.12865652, 0.00840133, ..., -0.00293668,
0.06652898, -0.03904634],
[ 0.02595548, 0.11789754, -0.02631662, ..., 0.03307972,
0.24783 , -0.12637296]], dtype=float32), 8), (array([[-0.09508003, 0.5282958 , 0.04243381, ..., -0.0648976 ,
-0.56710106, -0.06858341],
[-0.23444827, -0.17710732, -0.29838824, ..., -0.03992304,
-0.51030684, -0.34101528],
[ 0.00919831, 0.00875767, -0.09158213, ..., 0.26422626,
-0.19114144, 0.2717857 ],
...,
[-0.3377122 , 0.49312168, -0.50667596, ..., -0.33557674,
-0.3865259 , 0.4990052 ],
[ 0.17765625, -0.06699899, 0.29469523, ..., 0.612583 ,
0.13147196, -0.27174017],
[-0.56627136, -0.25801656, -0.28928643, ..., 0.18859185,
-0.46310693, 0.23317206]], dtype=float32), array([[ 0.22075664, 0.41444212, -0.4506032 , ..., -0.16359589,
-0.569371 , -0.13674185],
[-0.19642216, -0.36554584, -0.45241717, ..., -0.06579936,
-0.656618 , -0.11742882],
[-0.04687649, 0.10208729, -0.060711 , ..., 0.78385997,
-0.17697819, 0.75614196],
...,
[ 0.2150026 , 0.15821525, -0.16107246, ..., -0.20459318,
-0.20312549, 0.11470713],
[ 0.06135871, 0.06140292, -0.11083869, ..., 0.14165913,
0.05527819, 0.0520999 ],
[ 0.13505958, 0.06166949, -0.11568853, ..., -0.04749477,
0.02072733, -0.05888995]], dtype=float32), 9)]
collect() result is:
[(array([[ 0.18516347, -0.33975708, -0.46829244, ..., 0.498327 ,
-0.1628269 , 0.5171599 ],
[ 0.24843855, 0.43924475, 0.43121427, ..., -0.3605212 ,
-0.2543247 , -0.35761902],
[ 0.03349265, 0.28567392, -0.3129074 , ..., 0.30228034,
0.33539015, 0.28145155],
...,
[ 0.39538023, -0.11668223, 0.23590142, ..., -0.39222914,
-0.34792763, 0.43729994],
[-0.37299404, -0.40583754, -0.41405225, ..., 0.3708834 ,
0.6067088 , 0.5815965 ],
[-0.5297639 , 0.09037948, 0.06255247, ..., 0.55813074,
0.2599809 , 0.2930913 ]], dtype=float32), array([[-0.05441456, -0.08513332, -0.47554415, ..., 0.46570715,
0.11365455, 0.641596 ],
[ 0.25915185, 0.6784206 , 0.5428535 , ..., -0.3223893 ,
-0.17784661, -0.3021973 ],
[ 0.4007858 , 0.48166505, -0.7551351 , ..., 0.5893394 ,
0.5379706 , 0.5853663 ],
...,
[ 0.03785443, -0.02449032, 0.07482295, ..., 0.14570121,
0.02578176, 0.11021709],
[ 0.02157725, 0.20236807, -0.25889152, ..., -0.2123813 ,
0.11124409, 0.05835798],
[-0.2482176 , 0.1100187 , -0.16054511, ..., 0.15483692,
-0.01258843, -0.000899 ]], dtype=float32), 0), (array([[-0.06023056, 0.34326535, 0.01615176, ..., 0.50180113,
0.35740197, -0.3607464 ],
[-0.37374282, -0.05733229, -0.10494906, ..., 0.10067802,
-0.30181995, 0.19373518],
[-0.21093842, -0.35539758, 0.2722222 , ..., -0.13212524,
0.15457118, 0.29343936],
...,
[-0.3844776 , 0.29577827, 0.23207994, ..., -0.2748728 ,
0.05118364, -0.43278512],
[ 0.18988602, 0.15946351, -0.37208527, ..., 0.18980268,
0.26914784, 0.57002 ],
[-0.19685334, -0.00215623, -0.50676346, ..., -0.25601804,
0.43306062, -0.45977998]], dtype=float32), array([[ 0.07212897, 0.37888303, 0.14216022, ..., 0.3568563 ,
0.31809187, -0.30161127],
[-0.14504915, 0.09403719, -0.24099208, ..., 0.1194509 ,
-0.571604 , 0.3073006 ],
[-0.34506813, -0.5373435 , 0.39612344, ..., -0.17277275,
-0.15978633, 0.9480076 ],
...,
[ 0.33411705, 0.3616483 , -0.00284358, ..., -0.1570069 ,
-0.10693 , 0.11339971],
[ 0.03463184, 0.02923653, -0.07571009, ..., -0.01076985,
-0.24661644, 0.07400434],
[ 0.03500444, 0.02041529, 0.11345199, ..., -0.21085429,
-0.11910766, -0.16162656]], dtype=float32), 1), (array([[ 0.40148696, 0.44860718, -0.13891219, ..., 0.38478658,
0.0354558 , 0.5507143 ],
[ 0.35727167, 0.33523828, -0.40109852, ..., -0.12698223,
-0.48993534, 0.4485451 ],
[-0.37534824, -0.07898732, -0.28266475, ..., 0.0923453 ,
0.20398574, 0.46967202],
...,
[-0.31968838, 0.47479796, -0.49929148, ..., -0.23865293,
-0.24262336, -0.06511039],
[ 0.2098573 , -0.5782443 , 0.0044039 , ..., -0.13356705,
-0.5997722 , 0.24789433],
[ 0.18989192, -0.41790476, -0.5493083 , ..., -0.04386537,
0.14099114, -0.3851897 ]], dtype=float32), array([[ 0.35465854, 0.5298055 , 0.09639163, ..., 0.42402148,
0.1006598 , 0.42896628],
[ 0.49630994, 0.53968257, -0.45613742, ..., 0.11758496,
-0.82799774, 0.47129843],
[-0.11091773, -0.126068 , -0.94792765, ..., -0.39254257,
0.49629924, 0.90804875],
...,
[ 0.09986763, 0.18694244, 0.04551417, ..., -0.00185746,
0.04787954, 0.14079888],
[ 0.01976901, 0.01671817, 0.02434383, ..., -0.05640491,
0.03537085, -0.08094196],
[-0.12894829, 0.15826981, -0.09516723, ..., -0.11121278,
-0.17831786, -0.00805143]], dtype=float32), 2), (array([[-0.16927287, -0.3801918 , -0.32327962, ..., -0.51245123,
0.41986853, 0.18242987],
[-0.41638187, -0.06312063, -0.40284333, ..., 0.26918623,
-0.4305522 , -0.4801858 ],
[ 0.00497885, -0.22712015, -0.35257223, ..., 0.02938372,
0.32673585, 0.176891 ],
...,
[ 0.22629777, 0.39141867, 0.3272797 , ..., -0.45520803,
0.17408061, -0.27852598],
[-0.24445221, -0.35762975, -0.39768136, ..., 0.26196685,
0.17221238, 0.22423406],
[ 0.38731757, -0.45889175, 0.3848555 , ..., 0.469341 ,
0.2884723 , 0.4584588 ]], dtype=float32), array([[-0.38282454, -0.33555806, -0.38107285, ..., -0.4294134 ,
0.40941086, 0.19802842],
[-0.5812499 , -0.08969598, -0.5220119 , ..., 0.35569692,
-0.436198 , -0.78949076],
[-0.6598932 , -0.61103916, -0.8173852 , ..., 0.39758506,
0.77731544, 0.38008815],
...,
[ 0.16403008, -0.06079637, -0.02834259, ..., -0.10361861,
0.10504336, 0.21878755],
[-0.04492131, -0.17708012, -0.2025333 , ..., 0.17902693,
-0.15750957, -0.23726523],
[-0.07494884, -0.06165536, -0.10450385, ..., 0.17355536,
0.01258762, -0.05638846]], dtype=float32), 3), (array([[ 0.00608851, -0.5132972 , -0.36956814, ..., -0.5321781 ,
-0.26002124, 0.18518737],
[-0.21294856, -0.30260456, 0.11435414, ..., -0.52438897,
-0.46490332, -0.43702644],
[ 0.32582685, -0.25917143, -0.34966806, ..., -0.29239386,
-0.09908075, 0.32888958],
...,
[-0.25709772, 0.32678795, 0.00447032, ..., -0.04280795,
0.01051257, 0.36575 ],
[-0.07194624, -0.245111 , 0.5667958 , ..., 0.41409835,
-0.5799336 , -0.4858021 ],
[-0.11811656, 0.52633476, 0.40527648, ..., 0.40293694,
-0.4926284 , -0.14098077]], dtype=float32), array([[-0.11553568, -0.3299477 , -0.50079507, ..., -0.44186658,
0.01149882, 0.01778618],
[-0.3005174 , -0.26895422, 0.08186761, ..., -0.7348857 ,
-0.5810445 , -0.53310806],
[ 0.5055336 , -0.35796392, -0.90873146, ..., -0.78193223,
-0.04517283, 0.7075767 ],
...,
[-0.02022595, 0.01297146, -0.02664614, ..., 0.10258275,
-0.06243805, -0.07688553],
[-0.24571168, 0.18370496, 0.03886681, ..., -0.03063186,
-0.04676892, -0.10450852],
[ 0.03074093, -0.045911 , 0.07248624, ..., -0.05876327,
0.06366935, -0.01161662]], dtype=float32), 4), (array([[-0.28593373, 0.13395616, 0.48233178, ..., 0.508933 ,
-0.19197209, 0.3298264 ],
[-0.04904724, 0.34900156, 0.32834592, ..., -0.42706388,
-0.39813402, 0.14217453],
[ 0.13526852, 0.3745679 , 0.12265893, ..., 0.30098978,
0.15158501, 0.2164896 ],
...,
[ 0.21077481, 0.31915814, -0.32008937, ..., -0.48824465,
0.15033281, -0.55469203],
[-0.33247608, -0.05625575, 0.43155706, ..., 0.34942424,
0.04699863, 0.17167164],
[-0.32446697, 0.3883351 , -0.18434255, ..., -0.481489 ,
0.38606554, -0.31928998]], dtype=float32), array([[-0.21347475, 0.1707647 , 0.40377334, ..., 0.21519852,
-0.2119801 , 0.19800745],
[ 0.11622372, 0.4918646 , 0.4055998 , ..., -0.65552425,
-0.5142339 , 0.01027841],
[ 0.5708596 , 0.58031136, 0.5869403 , ..., 0.7259039 ,
0.10178877, 0.09589735],
...,
[-0.04035015, 0.08506791, 0.00343785, ..., 0.06102315,
0.07513183, -0.05011833],
[ 0.15481526, 0.14573523, -0.04516461, ..., -0.14253813,
-0.0463484 , -0.2259047 ],
[ 0.22652309, 0.2351952 , -0.03388155, ..., -0.04040325,
-0.17493977, -0.28690276]], dtype=float32), 5), (array([[ 0.01241684, 0.3848157 , -0.39109015, ..., 0.5478949 ,
-0.56430525, 0.6666779 ],
[ 0.14730261, -0.01858082, -0.3315102 , ..., 0.2657176 ,
-0.33017242, 0.21192974],
[ 0.33523917, -0.18178591, -0.3066008 , ..., 0.20106438,
-0.22945583, -0.03762559],
...,
[-0.2540486 , 0.18161409, -0.18843989, ..., -0.2343464 ,
-0.09964091, -0.06174641],
[-0.04365243, 0.5485354 , 0.20301312, ..., 0.46485335,
-0.5898763 , -0.32282397],
[-0.48032835, -0.44875026, 0.27545917, ..., -0.38980302,
-0.42448705, -0.28787732]], dtype=float32), array([[ 0.12385425, 0.34021133, -0.5947283 , ..., 0.5262172 ,
-0.50784826, 0.6099582 ],
[ 0.36343968, -0.00910554, -0.4031666 , ..., 0.29984346,
-0.51017505, 0.19207458],
[ 0.7371366 , -0.05033325, -0.65316683, ..., 0.6836462 ,
-0.82514614, 0.05400744],
...,
[ 0.00850144, 0.14323875, -0.17175199, ..., 0.02427784,
0.08420605, 0.08254603],
[-0.05369973, 0.1430688 , -0.18583198, ..., 0.22647178,
-0.23855914, 0.11469093],
[-0.16125312, 0.00549565, -0.1524472 , ..., 0.02182539,
-0.07359254, 0.13050811]], dtype=float32), 6), (array([[-0.04644716, 0.05705595, -0.24267559, ..., -0.21403529,
-0.06703684, 0.41887376],
[-0.29370484, -0.39780775, -0.3568661 , ..., -0.09881599,
0.07795003, 0.38119403],
[-0.3645612 , 0.10963462, -0.347853 , ..., -0.27757406,
-0.27060333, 0.3227043 ],
...,
[ 0.29615575, -0.4612898 , 0.05339388, ..., -0.08740558,
-0.25723857, -0.49486127],
[ 0.05623814, 0.4080824 , -0.24716274, ..., -0.09058192,
0.1756261 , -0.3089906 ],
[-0.5414306 , 0.4408762 , 0.42941254, ..., 0.15542302,
0.5396825 , -0.2709453 ]], dtype=float32), array([[-3.42568368e-01, 1.64543867e-01, 1.99411588e-04, ...,
-2.34075099e-01, 4.41714190e-02, 3.03957969e-01],
[-4.59636778e-01, -4.12208438e-01, -4.32887614e-01, ...,
-5.51471636e-02, 2.10178539e-01, 5.81714928e-01],
[-1.12530768e+00, 6.36167228e-01, -6.05338633e-01, ...,
-5.48392713e-01, -2.09195793e-01, 1.01132858e+00],
...,
[-1.30258754e-01, -4.16335762e-02, 1.45700663e-01, ...,
-3.17536071e-02, 3.02174967e-02, 1.10822864e-01],
[-1.39803439e-01, 3.05877943e-02, 1.05636232e-01, ...,
1.10526226e-01, 6.08176775e-02, -5.73274679e-02],
[-2.55209617e-02, 1.39828652e-01, -5.54599836e-02, ...,
1.40946835e-01, 1.95314527e-01, -7.23765837e-03]], dtype=float32), 7), (array([[-0.07237607, -0.06369619, -0.57799906, ..., -0.3011678 ,
-0.3869047 , -0.5708126 ],
[-0.00725966, -0.04352329, -0.14471681, ..., -0.47405225,
-0.11870398, 0.44799381],
[ 0.36965963, 0.19295754, -0.25880384, ..., -0.27418908,
-0.2637073 , -0.25275636],
...,
[ 0.06180732, 0.10883695, -0.20686714, ..., -0.4045689 ,
-0.10775824, -0.00597983],
[ 0.17705184, 0.2853461 , 0.38804924, ..., 0.00480051,
0.23195331, 0.5900061 ],
[ 0.33405438, -0.05846346, -0.49157378, ..., 0.13280089,
0.4277615 , -0.43489072]], dtype=float32), array([[-0.02955375, 0.03905029, -0.5314531 , ..., -0.3317717 ,
-0.32801694, -0.5832374 ],
[-0.15554425, 0.01338848, -0.30303237, ..., -0.5683995 ,
0.07281558, 0.39870417],
[ 0.57431275, 0.48248613, -0.36046496, ..., -0.740754 ,
-0.4101879 , -0.46861094],
...,
[-0.05291473, -0.00626671, -0.16552408, ..., -0.21676618,
0.00198667, 0.15837853],
[-0.20605578, 0.02732588, 0.05644984, ..., -0.14183098,
0.11027621, 0.00878341],
[ 0.01149986, 0.02788752, -0.07346534, ..., 0.02313588,
0.0365326 , 0.0427165 ]], dtype=float32), 8), (array([[ 0.32586667, 0.6187046 , 0.3162743 , ..., 0.42238498,
-0.08919488, -0.23198491],
[ 0.1516603 , 0.04449864, 0.10896807, ..., -0.49212578,
0.14955266, -0.04248938],
[-0.23296818, -0.16775374, 0.41552317, ..., -0.27849862,
0.08736038, 0.36777073],
...,
[ 0.17306753, 0.4053642 , -0.063707 , ..., 0.09530412,
-0.46045092, 0.42887986],
[-0.63025045, 0.3146556 , 0.148895 , ..., 0.48645812,
0.27349007, -0.0574788 ],
[-0.00731906, -0.03560491, -0.3711313 , ..., 0.2597622 ,
0.44751585, 0.30753264]], dtype=float32), array([[ 0.48764184, 0.54429394, 0.461604 , ..., 0.5304342 ,
0.20013283, -0.23899378],
[ 0.15820505, -0.131124 , -0.02082682, ..., -0.58288604,
0.04917484, -0.06569459],
[-0.74465793, -0.29858342, 1.2702694 , ..., -0.80466145,
0.45967343, 0.913658 ],
...,
[-0.02027037, -0.06839398, -0.14130257, ..., -0.14568064,
-0.02432712, 0.02081295],
[-0.03228899, -0.05201127, 0.1773543 , ..., 0.05981961,
-0.13131005, 0.00199337],
[-0.14305267, 0.02065172, 0.01255422, ..., 0.02369817,
0.0101981 , 0.13320476]], dtype=float32), 9)]
Any idea what causes this weird behaviour? Could it be related to byte order in NumPy ndarrays?
My fault. I forgot about Spark's lazy evaluation: map() is a transformation, not an action like take() and collect(), and transformations are not actually evaluated until an action is called.
This, combined with the fact that my mapping function is apparently not deterministic (i.e. it does not return the same value on repeated executions), leads to the following behaviour: calling take(10) evaluates the map() for the first time (returning one set of values), and calling collect() evaluates the map() a second time (returning a totally different set of values).
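To make the effect concrete without a Spark cluster, here is a toy model in plain Python. The class LazyDataset is hypothetical and is NOT the Spark API: it only mimics the idea that map() records a function while each action (take/collect) re-runs the pipeline from the source, unless the result was cached. In real PySpark the analogous fixes are calling .cache() or .persist() on the mapped RDD before the first action (cached partitions can still be recomputed if they are evicted), making the mapping function deterministic, or calling collect() once and slicing the local list.

```python
import random


class LazyDataset:
    """Hypothetical toy model of a lazily evaluated dataset; NOT the Spark API.

    map() only records the function (a "transformation"); take() and collect()
    ("actions") re-run the whole pipeline from the source each time, unless
    cache() was requested, in which case the first action's result is reused.
    """

    def __init__(self, source, funcs=(), cached=False):
        self._source = list(source)
        self._funcs = list(funcs)
        self._cached = cached
        self._result = None

    def map(self, f):
        # Transformation: nothing is evaluated here, the lineage just grows.
        return LazyDataset(self._source, self._funcs + [f], self._cached)

    def cache(self):
        # Ask for the first action's result to be kept and reused afterwards.
        return LazyDataset(self._source, self._funcs, cached=True)

    def _evaluate(self):
        if self._result is not None:
            return self._result
        out = []
        for item in self._source:
            for f in self._funcs:
                item = f(item)
            out.append(item)
        if self._cached:
            self._result = out
        return out

    def take(self, n):
        return self._evaluate()[:n]

    def collect(self):
        return list(self._evaluate())


rng = random.Random(0)


def noisy(i):
    # Non-deterministic mapping function: a different value on every evaluation.
    return (i, rng.random())


plain = LazyDataset(range(3)).map(noisy)
print(plain.take(3) == plain.collect())    # False: map() was evaluated twice

cached = LazyDataset(range(3)).map(noisy).cache()
print(cached.take(3) == cached.collect())  # True: evaluated once, then reused
```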

TensorFlow, Reshape like a convolution

I have a tensor of shape [3, 3, 256] and my final output must have shape [4, 2, 2, 256]. I need a reshape that works like a 'convolution' (in this case with a 2x2 filter) without changing the values. Is there a way to do this in TensorFlow?
If I understand your question correctly, you want to store the original values redundantly in the new structure, like this (without the last dim of 256):
[ [ 1 2 3 ]         [ [ [ 1 2 ]    [ [ 2 3 ]    [ [ 4 5 ]    [ [ 5 6 ]
  [ 4 5 6 ]   =>        [ 4 5 ] ],   [ 5 6 ] ],   [ 7 8 ] ],   [ 8 9 ] ] ]
  [ 7 8 9 ] ]
If so, you can slice out each 2x2 window with indexing (x being the original tensor) and then stack the slices:
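import tensorflow as tf

x2 = []
for i in range( 2 ):        # row offset of the 2x2 window
    for j in range( 2 ):    # column offset of the 2x2 window
        x2.append( x[ i : i + 2, j : j + 2, : ] )
y = tf.stack( x2, axis = 0 )  # shape [4, 2, 2, 256]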
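If you want to check the expected layout without TensorFlow, the same windows can be produced loop-free in NumPy. This is a swapped-in technique, not part of the original answer: it assumes NumPy >= 1.20 for numpy.lib.stride_tricks.sliding_window_view, and the helper name patches_2x2 is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # needs NumPy >= 1.20


def patches_2x2(x):
    """Hypothetical helper: all 2x2 spatial windows of an [H, W, C] array, stacked on axis 0."""
    # windows has shape [H-1, W-1, C, 2, 2]; windows[i, j, c, a, b] == x[i + a, j + b, c]
    windows = sliding_window_view(x, (2, 2), axis=(0, 1))
    # Move the window dims in front of the channel dim: [H-1, W-1, 2, 2, C]
    windows = windows.transpose(0, 1, 3, 4, 2)
    # Flatten window positions row-major (same order as the i/j loop): [(H-1)*(W-1), 2, 2, C]
    return windows.reshape(-1, 2, 2, x.shape[-1])


x = np.arange(3 * 3 * 256).reshape(3, 3, 256)
y = patches_2x2(x)
print(y.shape)  # (4, 2, 2, 256)

# Same result as the explicit slicing loop above
loop = np.stack([x[i:i + 2, j:j + 2, :] for i in range(2) for j in range(2)], axis=0)
print(np.array_equal(y, loop))  # True
```

Like tf.stack on sliced views, the final reshape makes a copy, so the result does not alias the input.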
Based on your comment, if you really want to avoid loops, you could use tf.extract_image_patches as below (tested code). Do run some benchmarks, though, because this may actually be less efficient than the slicing approach above:
import tensorflow as tf  # TF 1.x API (tf.Session, tf.extract_image_patches); TF 2.x has tf.image.extract_patches with `sizes` instead of `ksizes`

sess = tf.Session()
x = tf.constant( [ [ [ 1, -1 ], [ 2, -2 ], [ 3, -3 ] ],
                   [ [ 4, -4 ], [ 5, -5 ], [ 6, -6 ] ],
                   [ [ 7, -7 ], [ 8, -8 ], [ 9, -9 ] ] ] )  # shape [3, 3, 2]
xT = tf.transpose( x, perm = [ 2, 0, 1 ] ) # have to put channel dim as batch for tf.extract_image_patches
xTE = tf.expand_dims( xT, axis = -1 ) # extend dims to have fake channel dim
xP = tf.extract_image_patches( xTE, ksizes = [ 1, 2, 2, 1 ],
                               strides = [ 1, 1, 1, 1 ], rates = [ 1, 1, 1, 1 ], padding = "VALID" )
y = tf.transpose( xP, perm = [ 3, 1, 2, 0 ] ) # move dims back to original and new dim up front
print( sess.run(y) )
Output:
[[[[ 1 -1]
[ 2 -2]]
[[ 4 -4]
[ 5 -5]]]
[[[ 2 -2]
[ 3 -3]]
[[ 5 -5]
[ 6 -6]]]
[[[ 4 -4]
[ 5 -5]]
[[ 7 -7]
[ 8 -8]]]
[[[ 5 -5]
[ 6 -6]]
[[ 8 -8]
[ 9 -9]]]]
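Why does perm [3, 1, 2, 0] give the patches? tf.extract_image_patches returns [batch, out_rows, out_cols, ksize_rows * ksize_cols * depth], so after the transpose the first axis is the offset inside a patch and axes 1 and 2 are the patch position. For a 2x2 window on a 3x3 input both grids are 2x2 and, by symmetry of x[i + a, j + b], the result coincides with the stacked patches; for other sizes the shapes differ. The sketch below emulates that route in NumPy (an assumption-labeled stand-in, not TensorFlow; helper names are hypothetical; needs NumPy >= 1.20).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # needs NumPy >= 1.20


def emulate_extract_patches_route(x, k=2):
    """Hypothetical NumPy stand-in for transpose -> extract_image_patches -> transpose (not TF)."""
    xT = x.transpose(2, 0, 1)                           # [C, H, W]: channels act as the batch
    win = sliding_window_view(xT, (k, k), axis=(1, 2))  # [C, oh, ow, k, k]
    C, oh, ow = win.shape[:3]
    xP = win.reshape(C, oh, ow, k * k)                  # patch values flattened row-major, like the patch depth dim
    return xP.transpose(3, 1, 2, 0)                     # perm [3, 1, 2, 0] -> [k*k, oh, ow, C]


def explicit_patches(x, k=2):
    """Hypothetical helper: stack every k x k window of an [H, W, C] array -> [oh*ow, k, k, C]."""
    H, W, _ = x.shape
    return np.stack([x[i:i + k, j:j + k, :]
                     for i in range(H - k + 1) for j in range(W - k + 1)], axis=0)


x = np.array([[[1, -1], [2, -2], [3, -3]],
              [[4, -4], [5, -5], [6, -6]],
              [[7, -7], [8, -8], [9, -9]]])
a = emulate_extract_patches_route(x)
b = explicit_patches(x)
print(a.shape, b.shape)      # (4, 2, 2, 2) (4, 2, 2, 2)
print(np.array_equal(a, b))  # True for a 2x2 window on a 3x3 input

x4 = np.arange(16).reshape(4, 4, 1)
print(emulate_extract_patches_route(x4).shape, explicit_patches(x4).shape)  # (4, 3, 3, 1) (9, 2, 2, 1)
```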
I had a similar problem and found that tf.contrib.kfac.utils has a function called extract_convolution_patches (tf.contrib only exists in TF 1.x). Suppose you have a tensor X of shape (1, 3, 3, 256), where the leading 1 is the batch size; then you can call
Y = tf.contrib.kfac.utils.extract_convolution_patches(X, (2, 2, 256, 1), padding='VALID')
Y.shape # (1, 2, 2, 2, 2, 256)
The first two 2s index the patch positions (2x2 of them, which makes up the 4 in your description), and the last two 2s are the spatial shape of the filter. You can then call
Y = tf.reshape(Y, [4,2,2,256])
to get your final result.