I am migrating code from PyTorch to TensorFlow, and the function that calculates the loss contains the line below, which I need to port to TensorFlow:
state_action_values = net(t_states_features).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
I found tf.gather and tf.gather_nd, but I am not sure which one is more suitable or how to use it. Also, is tf.expand_dims the TensorFlow alternative to unsqueeze?
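To make the target semantics concrete, here is a small NumPy sketch of the same three steps on made-up values. NumPy is used only because it has direct analogues: np.expand_dims for unsqueeze and np.take_along_axis for gather along a dimension. For each row i, the result is the column actions[i]:

```python
import numpy as np

# Toy Q-values: 4 states x 3 actions (made-up numbers for illustration).
q = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0],
              [7.0, 8.0, 9.0],
              [10.0, 11.0, 12.0]], dtype=np.float32)
actions = np.array([2, 0, 1, 1])  # one chosen action per state

# actions_v.unsqueeze(-1): shape (4,) -> (4, 1); np.expand_dims is the analogue.
idx = np.expand_dims(actions, -1)

# .gather(1, idx): out[i, 0] = q[i, idx[i, 0]], shape (4, 1)
gathered = np.take_along_axis(q, idx, axis=1)

# .squeeze(-1): shape (4, 1) -> (4,)
selected = gathered.squeeze(-1)

# Same result with plain advanced indexing: row i, column actions[i].
assert np.array_equal(selected, q[np.arange(len(q)), actions])
print(selected)  # [ 3.  4.  8. 11.]
```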
In an attempt to get a clearer view of the line's result, I split it into multiple parts with print statements.
print("net result")
state_action_values = net(t_states_features)
print(state_action_values)
print("gather result")
state_action_values = state_action_values.gather(1, actions_v.unsqueeze(-1))
print(state_action_values)
print("last squeeze")
state_action_values = state_action_values.squeeze(-1)
print(state_action_values)

Output:
net result
tensor([[ 45.6878, -14.9495, 59.3737],
[ 33.5737, -10.4617, 39.0078],
[ 67.7197, -22.8818, 85.7977],
[ 94.7701, -33.2053, 120.5519],
[ nan, nan, nan],
[ 84.7324, -29.2101, 108.0821],
[ 67.7193, -22.7702, 86.9558],
[113.6835, -38.7149, 142.6167],
[ 61.9260, -20.1968, 79.8010],
[ 51.6152, -17.7391, 66.0719],
[ 73.6565, -21.5699, 98.9463],
[ 84.0761, -26.5016, 107.6888],
[ 60.9459, -20.1257, 76.4105],
[103.2883, -35.4035, 130.4503],
[ 37.1156, -13.5180, 47.1067],
[ nan, nan, nan],
[ 55.6286, -18.5239, 71.9837],
[ 55.3858, -18.7892, 71.1197],
[ 50.2419, -17.2959, 66.7059],
[ 82.5715, -30.0302, 108.4984],
[ -0.8662, -1.1861, 1.6033],
[112.4620, -38.6416, 142.4556],
[ 57.8702, -19.8080, 74.7656],
[ 45.8418, -15.7436, 57.3367],
[ 81.6596, -27.5002, 104.6002],
[ 57.1507, -21.8001, 67.7933],
[ 35.0414, -11.8199, 47.6573],
[ 67.7085, -23.1017, 85.4623],
[ 40.6284, -12.4578, 58.9603],
[ 68.6394, -23.1481, 87.0832],
[ 27.0549, -8.6635, 34.0150],
[ 25.4071, -8.5511, 34.0285],
[ 62.9161, -22.1693, 78.7965],
[ 85.4505, -28.1487, 108.6252],
[ 67.6665, -23.2376, 85.7117],
[ 60.7806, -20.2784, 77.1022],
[ 66.5209, -21.5674, 88.5561],
[ 61.6637, -20.9891, 72.3873],
[ 45.1634, -15.4678, 61.4886],
[ 66.8119, -23.1250, 85.6189],
[ nan, nan, nan],
[ 67.8166, -24.8342, 84.6706],
[ 86.2114, -29.5941, 107.8025],
[ 66.2716, -23.3309, 83.9700],
[101.2122, -35.3554, 127.4772],
[ 61.0749, -19.4720, 78.5588],
[ 50.4058, -16.1262, 63.1010],
[ 27.7543, -9.3767, 35.7448],
[ 67.7810, -23.4962, 83.6030],
[ 35.0103, -11.7238, 44.7983],
[ 55.7402, -19.0223, 70.3627],
[ 67.9733, -22.0783, 85.1893],
[ 60.5253, -20.3157, 79.7312],
[ 67.2404, -21.5205, 81.4499],
[ 57.9502, -20.7747, 70.9109],
[ 87.6536, -31.4256, 112.6491],
[ 90.3668, -30.7755, 116.6192],
[ 59.0660, -19.6988, 75.0723],
[ 50.0969, -17.4135, 62.6556],
[ 28.8703, -9.0950, 34.5749],
[ 68.4053, -22.0715, 88.2302],
[ 69.1397, -21.4236, 84.7833],
[ 23.8506, -8.1834, 30.8318],
[ 58.4296, -20.2432, 73.8116],
[ 87.5317, -29.0606, 110.0389],
[ nan, nan, nan],
[ 88.6387, -30.6154, 112.4239],
[ 51.6089, -16.1073, 66.2757],
[ 94.3989, -32.1473, 119.0358],
[ 82.7449, -30.7778, 102.8537],
[ 74.3067, -26.6585, 98.2536],
[ 77.0881, -26.5706, 98.3553],
[ 28.5688, -9.2949, 41.1165],
[ 86.1560, -26.9364, 107.0244],
[ 41.8914, -16.9703, 57.3840],
[ 88.8886, -29.7008, 108.2697],
[ 61.1243, -20.7566, 77.2257],
[ 85.1174, -28.7558, 107.3853],
[ 81.7256, -27.9047, 104.5006],
[ 51.2663, -16.5880, 67.1428],
[ 46.9150, -12.7457, 61.3240],
[ 36.1758, -12.9769, 47.7178],
[ 85.5846, -29.4141, 107.9649],
[ 59.9424, -20.8349, 75.3359],
[ 62.6516, -22.1235, 81.6903],
[104.7664, -34.5876, 129.9478],
[ 64.4671, -23.3980, 83.9093],
[ 69.6928, -23.6567, 89.6024],
[ 60.4407, -19.6136, 75.9350],
[ 33.4921, -10.3434, 44.9537],
[ 57.9112, -19.4174, 74.3050],
[ 24.8262, -9.3637, 30.1057],
[ 85.3776, -28.9097, 110.1310],
[ 63.8175, -22.3843, 81.0308],
[ 34.6040, -12.3217, 46.0356],
[ 88.3740, -29.5049, 110.2897],
[ 66.8196, -22.5860, 85.5386],
[ 58.9767, -22.0601, 78.7086],
[ 83.2090, -26.3499, 113.5105],
[ 54.8450, -17.7980, 68.1161],
[ nan, nan, nan],
[ 85.0846, -29.2494, 107.6780],
[ 76.9251, -26.2295, 98.4755],
[ 98.2907, -32.8878, 124.9192],
[ 91.1387, -30.8262, 115.3978],
[ 73.1062, -24.9450, 90.0967],
[ 27.6564, -8.6114, 35.4470],
[ 71.8508, -25.1529, 95.5165],
[ 69.7275, -20.1357, 86.9620],
[ 67.0907, -21.9245, 84.8853],
[ 77.3163, -25.5980, 92.7700],
[ 63.0082, -21.0345, 78.7311],
[ 68.0553, -22.4280, 84.8031],
[ 5.8148, -2.3171, 8.0620],
[103.3399, -35.1769, 130.7801],
[ 54.8769, -18.6822, 70.4657],
[ 58.4446, -18.9764, 75.5509],
[ 91.0071, -31.2706, 112.6401],
[ 84.6577, -29.2644, 104.6046],
[ 45.4887, -15.8309, 59.0498],
[ 56.3384, -18.9264, 78.8834],
[ 63.5109, -21.3169, 81.5144],
[ 79.4635, -29.8681, 100.5056],
[ 27.6559, -10.0517, 35.6012],
[ 76.3909, -24.1689, 93.6133],
[ 34.3802, -11.5272, 45.8650],
[ 60.3553, -20.1693, 76.5371],
[ 56.0590, -18.6468, 69.8981]], grad_fn=<AddmmBackward0>)
gather result
tensor([[ 59.3737],
[-10.4617],
[ 67.7197],
[ 94.7701],
[ nan],
[-29.2101],
[ 67.7193],
[-38.7149],
[-20.1968],
[ 66.0719],
[ 98.9463],
[107.6888],
[-20.1257],
[-35.4035],
[ 47.1067],
[ nan],
[ 55.6286],
[-18.7892],
[ 66.7059],
[-30.0302],
[ 1.6033],
[112.4620],
[ 74.7656],
[-15.7436],
[ 81.6596],
[-21.8001],
[ 35.0414],
[-23.1017],
[ 40.6284],
[ 68.6394],
[ 34.0150],
[ 34.0285],
[ 78.7965],
[-28.1487],
[ 67.6665],
[-20.2784],
[-21.5674],
[ 72.3873],
[-15.4678],
[ 85.6189],
[ nan],
[-24.8342],
[-29.5941],
[-23.3309],
[101.2122],
[-19.4720],
[-16.1262],
[ -9.3767],
[-23.4962],
[-11.7238],
[ 70.3627],
[-22.0783],
[-20.3157],
[ 67.2404],
[-20.7747],
[112.6491],
[-30.7755],
[-19.6988],
[ 50.0969],
[ 34.5749],
[ 88.2302],
[-21.4236],
[ -8.1834],
[ 73.8116],
[110.0389],
[ nan],
[112.4239],
[-16.1073],
[-32.1473],
[-30.7778],
[ 98.2536],
[ 98.3553],
[ 28.5688],
[107.0244],
[-16.9703],
[-29.7008],
[ 77.2257],
[-28.7558],
[-27.9047],
[ 67.1428],
[-12.7457],
[ 47.7178],
[-29.4141],
[ 59.9424],
[-22.1235],
[129.9478],
[-23.3980],
[-23.6567],
[ 75.9350],
[-10.3434],
[-19.4174],
[ 30.1057],
[ 85.3776],
[ 63.8175],
[ 46.0356],
[-29.5049],
[-22.5860],
[-22.0601],
[113.5105],
[-17.7980],
[ nan],
[-29.2494],
[ 76.9251],
[-32.8878],
[115.3978],
[-24.9450],
[ 35.4470],
[ 95.5165],
[ 86.9620],
[-21.9245],
[-25.5980],
[ 78.7311],
[-22.4280],
[ 5.8148],
[103.3399],
[ 70.4657],
[ 58.4446],
[ 91.0071],
[104.6046],
[ 45.4887],
[-18.9264],
[ 63.5109],
[ 79.4635],
[-10.0517],
[ 76.3909],
[ 34.3802],
[-20.1693],
[-18.6468]], grad_fn=<GatherBackward0>)
last squeeze
tensor([ 59.3737, -10.4617, 67.7197, 94.7701, nan, -29.2101, 67.7193,
-38.7149, -20.1968, 66.0719, 98.9463, 107.6888, -20.1257, -35.4035,
47.1067, nan, 55.6286, -18.7892, 66.7059, -30.0302, 1.6033,
112.4620, 74.7656, -15.7436, 81.6596, -21.8001, 35.0414, -23.1017,
40.6284, 68.6394, 34.0150, 34.0285, 78.7965, -28.1487, 67.6665,
-20.2784, -21.5674, 72.3873, -15.4678, 85.6189, nan, -24.8342,
-29.5941, -23.3309, 101.2122, -19.4720, -16.1262, -9.3767, -23.4962,
-11.7238, 70.3627, -22.0783, -20.3157, 67.2404, -20.7747, 112.6491,
-30.7755, -19.6988, 50.0969, 34.5749, 88.2302, -21.4236, -8.1834,
73.8116, 110.0389, nan, 112.4239, -16.1073, -32.1473, -30.7778,
98.2536, 98.3553, 28.5688, 107.0244, -16.9703, -29.7008, 77.2257,
-28.7558, -27.9047, 67.1428, -12.7457, 47.7178, -29.4141, 59.9424,
-22.1235, 129.9478, -23.3980, -23.6567, 75.9350, -10.3434, -19.4174,
30.1057, 85.3776, 63.8175, 46.0356, -29.5049, -22.5860, -22.0601,
113.5105, -17.7980, nan, -29.2494, 76.9251, -32.8878, 115.3978,
-24.9450, 35.4470, 95.5165, 86.9620, -21.9245, -25.5980, 78.7311,
-22.4280, 5.8148, 103.3399, 70.4657, 58.4446, 91.0071, 104.6046,
45.4887, -18.9264, 63.5109, 79.4635, -10.0517, 76.3909, 34.3802,
-20.1693, -18.6468], grad_fn=<SqueezeBackward1>)
Edit 1: print of actions_v
actions_v
tensor([2, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 2, 0, 2, 0, 1,
2, 0, 2, 1, 1, 0, 2, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1,
0, 2, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1, 2, 0, 0, 1, 1, 2, 0, 0, 2, 0, 0,
1, 1, 2, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 2, 0, 2, 0, 1, 1, 2, 1, 2, 2,
2, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 2, 1, 1, 0, 1, 0, 1, 2,
2, 1, 0, 2, 0, 0, 2, 1])
tf.gather_nd takes indices whose innermost dimension can match the rank of the input tensor. In that case each index is a full coordinate, and the output contains the values found at those coordinates, which is what you want.
tf.gather, on the other hand, outputs slices. You can give the indices any shape you want, but the output is just a collection of slices arranged according to that shape, which is not what you want here.
So you should first build indices that match the dimensions of the original matrix, i.e. one (row, action) pair per row:
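The following sketch uses a toy 2x3 tensor with made-up values and assumes TensorFlow 2.x with eager execution. tf.gather along axis 1 returns whole columns, while tf.gather_nd with (row, column) coordinates returns one element per coordinate:

```python
import tensorflow as tf

# Toy Q-values: 2 states x 3 actions.
q = tf.constant([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])
actions = tf.constant([2, 0])  # action chosen in each state

# tf.gather returns slices: along axis=1 every index selects a whole column,
# so the result has shape (rows, len(actions)) = (2, 2), not one value per row.
cols = tf.gather(q, actions, axis=1)
print(cols.numpy())  # [[3. 1.], [6. 4.]]; the wanted values sit on the diagonal

# tf.gather_nd treats each innermost index vector as a full (row, col)
# coordinate and returns the single element found there.
coords = tf.constant([[0, 2], [1, 0]])
values = tf.gather_nd(q, coords)
print(values.numpy())  # [3. 4.]
```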
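# one (row, action) pair per row; the cast is needed because tf.range yields int32 and actions_v may be int64
indices = tf.stack((tf.range(tf.shape(state_action_values)[0]), tf.cast(actions_v, tf.int32)), axis=1)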
And then gather_nd:
state_action_values = tf.gather_nd(state_action_values,indices)
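The answer can be combined into a drop-in replacement for the PyTorch line. This is a minimal sketch assuming TensorFlow 2.x, and the function name select_action_values is hypothetical.

The sample values are the first two rows of the net output printed in the question. The sketch uses actions 2 and 1, so the result can be compared with the first two gather values (59.3737 and -10.4617).

The last lines show an alternative that is not part of the answer above: tf.gather with batch_dims=1, which also pairs row i with actions[i].

```python
import tensorflow as tf


def select_action_values(q_values, actions):
    """TF counterpart of q_values.gather(1, actions.unsqueeze(-1)).squeeze(-1).

    q_values: float tensor of shape [batch, n_actions]
    actions:  integer tensor of shape [batch]
    returns:  tensor of shape [batch] holding q_values[i, actions[i]]
    """
    # Cast so both stacked tensors share a dtype (tf.range yields int32).
    actions = tf.cast(actions, tf.int32)
    rows = tf.range(tf.shape(q_values)[0])
    # Stack along axis=1 to get [batch, 2] (row, action) pairs; this is
    # equivalent to tf.transpose(tf.stack((rows, actions))).
    indices = tf.stack([rows, actions], axis=1)
    return tf.gather_nd(q_values, indices)


# First two rows of the printed net output, with actions 2 and 1.
q = tf.constant([[45.6878, -14.9495, 59.3737],
                 [33.5737, -10.4617, 39.0078]])
a = tf.constant([2, 1], dtype=tf.int64)

selected = select_action_values(q, a)
print(selected.numpy())

# Alternative: batch_dims=1 treats dim 0 as a batch, so row i uses actions[i]
# and the output shape is [batch].
alt = tf.gather(q, tf.cast(a, tf.int32), axis=1, batch_dims=1)
print(alt.numpy())
```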
Keivan
I run a PySpark script that performs a map operation on an RDD. The result of this map is a new RDD containing tuples, one for each partition of the original RDD (there are 10 partitions). Each tuple contains two NumPy ndarrays and the partition ID.
The problem is that take(10) and collect() return different values! Even if I call first() on the resulting RDD, the tuple is completely different from the first tuple returned by take(10) or collect().
take(10) result is:
[(array([[ 0.19138815, -0.26613894, 0.0148395 , ..., -0.6887879 ,
0.01775263, 0.29900053],
[ 0.46834013, -0.41492677, -0.3986189 , ..., -0.09638319,
-0.27234066, 0.41824088],
[ 0.2235235 , -0.27003226, 0.05322047, ..., -0.3045229 ,
-0.30364496, 0.21981548],
...,
[ 0.40709212, -0.49947056, 0.36821032, ..., -0.27359277,
-0.1552616 , -0.10506155],
[ 0.42315334, -0.6249347 , -0.38093382, ..., -0.52247494,
0.12167282, 0.53337336],
[-0.12733063, -0.27234274, -0.05421005, ..., 0.3884521 ,
0.19977048, 0.11347781]], dtype=float32), array([[ 0.04737013, -0.13501787, -0.04363494, ..., -0.53189725,
-0.1747366 , 0.09700331],
[ 0.59339726, -0.622871 , -0.7482438 , ..., 0.01563211,
-0.24306081, 0.5259429 ],
[ 0.199726 , -0.29667157, 0.03788655, ..., -0.72914207,
-0.61997485, 0.4880442 ],
...,
[ 0.10908518, 0.15787347, 0.13822514, ..., -0.015803 ,
-0.05935803, 0.11415796],
[ 0.02415699, 0.21212797, -0.01504083, ..., -0.0880443 ,
0.03949256, -0.17959005],
[-0.08058456, -0.00894655, -0.10706384, ..., 0.00954069,
-0.18720922, -0.05665499]], dtype=float32), 0), (array([[ 0.53325355, -0.0129254 , 0.10924862, ..., 0.2797827 ,
0.44130138, -0.29074535],
[ 0.29487053, 0.20387554, 0.00834447, ..., 0.3034479 ,
-0.34347925, 0.48914096],
[-0.36053488, 0.2551153 , 0.23102154, ..., -0.08557958,
0.2305064 , -0.11637823],
...,
[ 0.50112355, -0.5487336 , 0.1381122 , ..., -0.17219128,
0.5784589 , -0.39060545],
[ 0.47822553, -0.21500733, -0.02590418, ..., 0.45222896,
-0.29980502, -0.4379743 ],
[-0.27317327, 0.47888 , 0.13328783, ..., 0.45453754,
-0.03382564, 0.28364402]], dtype=float32), array([[ 0.45505902, 0.20775506, 0.38936007, ..., 0.31039652,
0.0612184 , -0.34329525],
[ 0.44160137, 0.45554656, 0.24406506, ..., 0.20645882,
-0.55068386, 0.61949503],
[-0.9611948 , 0.48338398, 0.8936671 , ..., -0.07192033,
-0.04691654, -0.0583482 ],
...,
[ 0.27866694, 0.11791337, 0.00603435, ..., -0.0984261 ,
0.05514587, -0.16367936],
[ 0.24686542, -0.01605012, 0.15803055, ..., 0.14647359,
0.00465332, -0.21551773],
[ 0.01786835, 0.08164094, 0.0458132 , ..., 0.09466646,
-0.07061186, 0.04650302]], dtype=float32), 1), (array([[-0.30373037, 0.4013883 , 0.5544747 , ..., -0.24839583,
-0.434404 , -0.5419062 ],
[ 0.3205092 , -0.21219605, 0.23547144, ..., -0.4373149 ,
0.30616343, -0.45202586],
[-0.25409338, -0.14463192, 0.30881113, ..., -0.29998812,
-0.24947752, -0.18218543],
...,
[-0.33871105, 0.20521186, 0.38351473, ..., 0.33596635,
0.370852 , -0.46504658],
[-0.23807049, 0.1317612 , 0.07848132, ..., -0.32858378,
-0.541284 , -0.4595052 ],
[-0.17040105, 0.48929718, -0.39259502, ..., -0.15026243,
-0.19829535, 0.18581793]], dtype=float32), array([[-0.5295805 , 0.64539313, 0.47219488, ..., -0.4475045 ,
-0.31032106, -0.60634965],
[ 0.28920087, -0.2322041 , 0.313627 , ..., -0.71224356,
0.1539378 , -0.39970958],
[-0.39529595, -0.50494266, 0.61120296, ..., -0.94842255,
-0.7700451 , 0.03852249],
...,
[ 0.082367 , -0.09744105, 0.00574634, ..., 0.04329732,
0.08105459, -0.20032766],
[ 0.04620253, 0.18625231, -0.04047911, ..., -0.1547844 ,
-0.01560262, 0.00372486],
[ 0.0602592 , 0.078048 , 0.04372916, ..., -0.10393928,
-0.27185628, -0.05753115]], dtype=float32), 2), (array([[-0.10130739, -0.5432671 , 0.14230369, ..., -0.20037425,
-0.50981474, -0.39152429],
[-0.10439714, -0.19250502, -0.12469167, ..., 0.50656915,
0.41846293, -0.12511848],
[-0.12075678, 0.13746399, -0.10762265, ..., -0.33095708,
0.38831544, -0.3573719 ],
...,
[-0.33046186, 0.1675668 , -0.22495636, ..., 0.39853546,
-0.21838626, -0.44713587],
[ 0.1753454 , -0.16229314, 0.24015644, ..., -0.09867255,
-0.6676028 , -0.03644819],
[ 0.33300576, -0.28793842, 0.45033735, ..., 0.26602328,
0.18528135, -0.37736982]], dtype=float32), array([[-0.4201235 , -0.42547926, 0.3823224 , ..., -0.14358595,
-0.33159116, -0.47784004],
[-0.29869375, -0.31769726, 0.00777127, ..., 0.5764582 ,
0.160321 , -0.14145109],
[ 0.08045941, 0.31604633, -0.22956751, ..., -0.551948 ,
0.8479948 , -1.065378 ],
...,
[ 0.02872542, -0.23239408, 0.10584452, ..., -0.19653268,
0.01299276, 0.01154038],
[-0.11145593, -0.04056952, 0.17250092, ..., 0.11815882,
-0.09970374, -0.14516412],
[ 0.03227303, 0.08426446, -0.04518193, ..., -0.04804583,
0.11866879, 0.07962584]], dtype=float32), 3), (array([[-0.3442908 , -0.47924808, -0.17909145, ..., 0.57778424,
-0.20964314, -0.38468337],
[-0.1493048 , 0.22245084, -0.19882523, ..., 0.38685095,
-0.5174358 , -0.4666229 ],
[-0.36909524, -0.32142702, -0.183554 , ..., 0.16354944,
-0.31103647, -0.12249314],
...,
[-0.45103338, -0.04599327, 0.3811725 , ..., -0.28528866,
-0.47804165, -0.08458076],
[ 0.2814856 , -0.07282971, -0.6203528 , ..., 0.1307708 ,
0.01159024, -0.28599057],
[-0.20529339, 0.39615756, -0.12310734, ..., 0.21978702,
0.09362441, 0.282171 ]], dtype=float32), array([[-3.9894256e-01, -2.3258871e-01, -3.5374862e-01, ...,
5.5857885e-01, 1.3151944e-02, -3.4470174e-01],
[ 3.7879427e-03, 2.3241083e-01, -1.8348652e-01, ...,
2.8142187e-01, -5.6694925e-01, -7.0757174e-01],
[-1.0855644e+00, -1.0010500e+00, -5.0151312e-01, ...,
1.0360012e+00, -8.4597129e-01, 3.6014456e-01],
...,
[ 9.6085869e-02, 4.7818022e-03, 9.2197888e-02, ...,
5.7999889e-04, -4.3681998e-02, -1.0792557e-01],
[-2.1461031e-01, -7.9450890e-02, -3.6970940e-03, ...,
1.2008668e-01, 8.1481531e-02, -7.4978117e-03],
[ 2.2965204e-02, -7.5380646e-02, -1.5798187e-01, ...,
2.3251076e-01, 2.0147276e-01, -5.9070695e-02]], dtype=float32), 4), (array([[ 0.34773657, 0.1271629 , 0.35084245, ..., -0.4163662 ,
-0.2032509 , 0.16763268],
[ 0.25300634, -0.24963312, 0.03022493, ..., -0.48368084,
-0.5438243 , -0.2523892 ],
[ 0.08074404, 0.34778172, -0.19389851, ..., 0.00394286,
-0.03249731, 0.21850894],
...,
[-0.28799132, 0.46877176, 0.13380927, ..., 0.21133232,
-0.20794198, 0.04732013],
[ 0.533321 , 0.08838256, 0.52776694, ..., -0.3776623 ,
0.3234356 , 0.49971387],
[ 0.01111342, 0.2225688 , 0.5407952 , ..., -0.5621694 ,
-0.48505995, 0.45497927]], dtype=float32), array([[ 0.31677225, 0.3403609 , 0.43394235, ..., -0.22031932,
-0.19851643, 0.07987656],
[ 0.23734477, -0.08880972, 0.29560354, ..., -0.29227307,
-0.6075761 , 0.0549405 ],
[ 0.51728237, 0.92628396, -0.40312177, ..., 0.47099453,
0.1400124 , 0.5190122 ],
...,
[-0.01423888, -0.00644734, -0.01985722, ..., 0.00387737,
0.08140853, -0.04754239],
[-0.20655577, 0.07879162, 0.16483389, ..., 0.03559271,
-0.08346304, 0.02901335],
[ 0.12661955, -0.14012276, 0.13914892, ..., 0.09200166,
-0.01300973, 0.09209644]], dtype=float32), 5), (array([[ 0.6097177 , 0.48289144, 0.107826 , ..., 0.14263372,
0.43684664, 0.37251818],
[-0.47600058, -0.32621023, -0.55021507, ..., -0.48766413,
0.19021659, -0.25467217],
[-0.26290026, 0.19371034, -0.3084416 , ..., -0.38468444,
-0.24624074, -0.1248633 ],
...,
[ 0.29716069, -0.17764345, 0.06396675, ..., -0.51036185,
-0.16205388, -0.11298727],
[ 0.11154262, 0.17699954, -0.55426383, ..., 0.36997363,
0.10982917, -0.16837421],
[-0.0491058 , 0.13139851, 0.38043728, ..., -0.20505017,
0.15700234, 0.5629593 ]], dtype=float32), array([[ 5.4628938e-01, 4.5117161e-01, 3.2670802e-01, ...,
1.9686517e-01, 3.2789686e-01, 4.5343029e-01],
[-6.4986098e-01, -2.4272519e-01, -7.6540470e-01, ...,
-7.3495942e-01, 2.8457634e-02, -1.3791913e-01],
[-7.1882963e-01, 6.7012483e-01, -6.4978844e-01, ...,
-6.2161821e-01, -6.2733275e-01, -3.5520321e-01],
...,
[-9.8777540e-02, -3.5368185e-02, 3.7035108e-02, ...,
3.2906037e-02, -3.7386462e-02, -7.3537357e-02],
[ 7.1388625e-02, 1.7618576e-01, -3.5727687e-02, ...,
-1.8145926e-01, -8.1750348e-02, 1.6553156e-02],
[-4.8265442e-02, -1.0864433e-01, -7.2953522e-02, ...,
5.3536858e-02, 1.4538332e-04, 1.6074999e-01]], dtype=float32), 6), (array([[-0.60614747, 0.12892431, -0.3011826 , ..., 0.25248086,
-0.52810246, 0.14257856],
[-0.42474708, -0.37544054, -0.3886031 , ..., -0.35798463,
-0.18877436, 0.3048291 ],
[-0.17884058, -0.37839454, 0.3589297 , ..., 0.05497296,
-0.06037642, -0.4278129 ],
...,
[ 0.1893682 , 0.22125317, -0.20827802, ..., -0.29234493,
-0.1302274 , 0.02801383],
[ 0.1747711 , -0.0879383 , -0.395539 , ..., -0.38479805,
-0.61469847, 0.00207329],
[ 0.3819209 , 0.49023125, -0.42264247, ..., 0.1528128 ,
0.45578462, 0.4668125 ]], dtype=float32), array([[-5.28453410e-01, 1.38694681e-02, -2.60350943e-01, ...,
2.34538317e-01, -4.43873584e-01, 1.47595644e-01],
[-5.25614679e-01, -4.94065553e-01, -4.03498739e-01, ...,
-4.45649236e-01, -1.77955553e-01, 3.10622215e-01],
[ 1.01587474e-01, -8.57380509e-01, 8.03865731e-01, ...,
2.06392854e-01, -1.75320432e-01, -7.59775221e-01],
...,
[ 8.77973512e-02, -2.19333827e-01, -1.42306268e-01, ...,
-1.03031479e-01, 6.18839522e-06, 1.56334475e-01],
[-1.23811606e-02, -3.57912146e-02, -7.11395666e-02, ...,
4.88635562e-02, 1.20033674e-01, -2.04370469e-02],
[ 1.00751802e-01, -4.78806317e-01, 1.95426181e-01, ...,
-5.01195967e-01, -3.37407261e-01, 3.57391506e-01]], dtype=float32), 7), (array([[ 0.49053288, -0.04683368, -0.0879433 , ..., 0.6166893 ,
-0.5472508 , 0.5924473 ],
[-0.40197268, 0.21878959, 0.47748646, ..., -0.27519724,
0.3854015 , -0.04976773],
[ 0.24380569, -0.04194092, -0.22590604, ..., 0.35376453,
0.2546404 , -0.20142618],
...,
[-0.5142191 , -0.2877738 , -0.47166097, ..., 0.48306477,
0.26082426, -0.4445646 ],
[-0.25510445, 0.00508366, -0.5078671 , ..., -0.27604827,
-0.08479042, 0.04850767],
[-0.37874135, 0.49107817, 0.11259978, ..., 0.3188926 ,
0.15944068, 0.0829725 ]], dtype=float32), array([[ 0.5134689 , -0.18948251, -0.05751067, ..., 0.48557195,
-0.3173096 , 0.372445 ],
[-0.43809155, -0.08225048, 0.7466321 , ..., -0.5873075 ,
0.7402181 , -0.09714872],
[ 0.68586934, 0.1767921 , 0.07150997, ..., 0.8427489 ,
0.45624557, -0.32352042],
...,
[ 0.02128473, -0.00579748, 0.12406835, ..., 0.08425785,
0.21446604, 0.09736192],
[ 0.04778935, -0.12865652, 0.00840133, ..., -0.00293668,
0.06652898, -0.03904634],
[ 0.02595548, 0.11789754, -0.02631662, ..., 0.03307972,
0.24783 , -0.12637296]], dtype=float32), 8), (array([[-0.09508003, 0.5282958 , 0.04243381, ..., -0.0648976 ,
-0.56710106, -0.06858341],
[-0.23444827, -0.17710732, -0.29838824, ..., -0.03992304,
-0.51030684, -0.34101528],
[ 0.00919831, 0.00875767, -0.09158213, ..., 0.26422626,
-0.19114144, 0.2717857 ],
...,
[-0.3377122 , 0.49312168, -0.50667596, ..., -0.33557674,
-0.3865259 , 0.4990052 ],
[ 0.17765625, -0.06699899, 0.29469523, ..., 0.612583 ,
0.13147196, -0.27174017],
[-0.56627136, -0.25801656, -0.28928643, ..., 0.18859185,
-0.46310693, 0.23317206]], dtype=float32), array([[ 0.22075664, 0.41444212, -0.4506032 , ..., -0.16359589,
-0.569371 , -0.13674185],
[-0.19642216, -0.36554584, -0.45241717, ..., -0.06579936,
-0.656618 , -0.11742882],
[-0.04687649, 0.10208729, -0.060711 , ..., 0.78385997,
-0.17697819, 0.75614196],
...,
[ 0.2150026 , 0.15821525, -0.16107246, ..., -0.20459318,
-0.20312549, 0.11470713],
[ 0.06135871, 0.06140292, -0.11083869, ..., 0.14165913,
0.05527819, 0.0520999 ],
[ 0.13505958, 0.06166949, -0.11568853, ..., -0.04749477,
0.02072733, -0.05888995]], dtype=float32), 9)]
collect() result is:
[(array([[ 0.18516347, -0.33975708, -0.46829244, ..., 0.498327 ,
-0.1628269 , 0.5171599 ],
[ 0.24843855, 0.43924475, 0.43121427, ..., -0.3605212 ,
-0.2543247 , -0.35761902],
[ 0.03349265, 0.28567392, -0.3129074 , ..., 0.30228034,
0.33539015, 0.28145155],
...,
[ 0.39538023, -0.11668223, 0.23590142, ..., -0.39222914,
-0.34792763, 0.43729994],
[-0.37299404, -0.40583754, -0.41405225, ..., 0.3708834 ,
0.6067088 , 0.5815965 ],
[-0.5297639 , 0.09037948, 0.06255247, ..., 0.55813074,
0.2599809 , 0.2930913 ]], dtype=float32), array([[-0.05441456, -0.08513332, -0.47554415, ..., 0.46570715,
0.11365455, 0.641596 ],
[ 0.25915185, 0.6784206 , 0.5428535 , ..., -0.3223893 ,
-0.17784661, -0.3021973 ],
[ 0.4007858 , 0.48166505, -0.7551351 , ..., 0.5893394 ,
0.5379706 , 0.5853663 ],
...,
[ 0.03785443, -0.02449032, 0.07482295, ..., 0.14570121,
0.02578176, 0.11021709],
[ 0.02157725, 0.20236807, -0.25889152, ..., -0.2123813 ,
0.11124409, 0.05835798],
[-0.2482176 , 0.1100187 , -0.16054511, ..., 0.15483692,
-0.01258843, -0.000899 ]], dtype=float32), 0), (array([[-0.06023056, 0.34326535, 0.01615176, ..., 0.50180113,
0.35740197, -0.3607464 ],
[-0.37374282, -0.05733229, -0.10494906, ..., 0.10067802,
-0.30181995, 0.19373518],
[-0.21093842, -0.35539758, 0.2722222 , ..., -0.13212524,
0.15457118, 0.29343936],
...,
[-0.3844776 , 0.29577827, 0.23207994, ..., -0.2748728 ,
0.05118364, -0.43278512],
[ 0.18988602, 0.15946351, -0.37208527, ..., 0.18980268,
0.26914784, 0.57002 ],
[-0.19685334, -0.00215623, -0.50676346, ..., -0.25601804,
0.43306062, -0.45977998]], dtype=float32), array([[ 0.07212897, 0.37888303, 0.14216022, ..., 0.3568563 ,
0.31809187, -0.30161127],
[-0.14504915, 0.09403719, -0.24099208, ..., 0.1194509 ,
-0.571604 , 0.3073006 ],
[-0.34506813, -0.5373435 , 0.39612344, ..., -0.17277275,
-0.15978633, 0.9480076 ],
...,
[ 0.33411705, 0.3616483 , -0.00284358, ..., -0.1570069 ,
-0.10693 , 0.11339971],
[ 0.03463184, 0.02923653, -0.07571009, ..., -0.01076985,
-0.24661644, 0.07400434],
[ 0.03500444, 0.02041529, 0.11345199, ..., -0.21085429,
-0.11910766, -0.16162656]], dtype=float32), 1), (array([[ 0.40148696, 0.44860718, -0.13891219, ..., 0.38478658,
0.0354558 , 0.5507143 ],
[ 0.35727167, 0.33523828, -0.40109852, ..., -0.12698223,
-0.48993534, 0.4485451 ],
[-0.37534824, -0.07898732, -0.28266475, ..., 0.0923453 ,
0.20398574, 0.46967202],
...,
[-0.31968838, 0.47479796, -0.49929148, ..., -0.23865293,
-0.24262336, -0.06511039],
[ 0.2098573 , -0.5782443 , 0.0044039 , ..., -0.13356705,
-0.5997722 , 0.24789433],
[ 0.18989192, -0.41790476, -0.5493083 , ..., -0.04386537,
0.14099114, -0.3851897 ]], dtype=float32), array([[ 0.35465854, 0.5298055 , 0.09639163, ..., 0.42402148,
0.1006598 , 0.42896628],
[ 0.49630994, 0.53968257, -0.45613742, ..., 0.11758496,
-0.82799774, 0.47129843],
[-0.11091773, -0.126068 , -0.94792765, ..., -0.39254257,
0.49629924, 0.90804875],
...,
[ 0.09986763, 0.18694244, 0.04551417, ..., -0.00185746,
0.04787954, 0.14079888],
[ 0.01976901, 0.01671817, 0.02434383, ..., -0.05640491,
0.03537085, -0.08094196],
[-0.12894829, 0.15826981, -0.09516723, ..., -0.11121278,
-0.17831786, -0.00805143]], dtype=float32), 2), (array([[-0.16927287, -0.3801918 , -0.32327962, ..., -0.51245123,
0.41986853, 0.18242987],
[-0.41638187, -0.06312063, -0.40284333, ..., 0.26918623,
-0.4305522 , -0.4801858 ],
[ 0.00497885, -0.22712015, -0.35257223, ..., 0.02938372,
0.32673585, 0.176891 ],
...,
[ 0.22629777, 0.39141867, 0.3272797 , ..., -0.45520803,
0.17408061, -0.27852598],
[-0.24445221, -0.35762975, -0.39768136, ..., 0.26196685,
0.17221238, 0.22423406],
[ 0.38731757, -0.45889175, 0.3848555 , ..., 0.469341 ,
0.2884723 , 0.4584588 ]], dtype=float32), array([[-0.38282454, -0.33555806, -0.38107285, ..., -0.4294134 ,
0.40941086, 0.19802842],
[-0.5812499 , -0.08969598, -0.5220119 , ..., 0.35569692,
-0.436198 , -0.78949076],
[-0.6598932 , -0.61103916, -0.8173852 , ..., 0.39758506,
0.77731544, 0.38008815],
...,
[ 0.16403008, -0.06079637, -0.02834259, ..., -0.10361861,
0.10504336, 0.21878755],
[-0.04492131, -0.17708012, -0.2025333 , ..., 0.17902693,
-0.15750957, -0.23726523],
[-0.07494884, -0.06165536, -0.10450385, ..., 0.17355536,
0.01258762, -0.05638846]], dtype=float32), 3), (array([[ 0.00608851, -0.5132972 , -0.36956814, ..., -0.5321781 ,
-0.26002124, 0.18518737],
[-0.21294856, -0.30260456, 0.11435414, ..., -0.52438897,
-0.46490332, -0.43702644],
[ 0.32582685, -0.25917143, -0.34966806, ..., -0.29239386,
-0.09908075, 0.32888958],
...,
[-0.25709772, 0.32678795, 0.00447032, ..., -0.04280795,
0.01051257, 0.36575 ],
[-0.07194624, -0.245111 , 0.5667958 , ..., 0.41409835,
-0.5799336 , -0.4858021 ],
[-0.11811656, 0.52633476, 0.40527648, ..., 0.40293694,
-0.4926284 , -0.14098077]], dtype=float32), array([[-0.11553568, -0.3299477 , -0.50079507, ..., -0.44186658,
0.01149882, 0.01778618],
[-0.3005174 , -0.26895422, 0.08186761, ..., -0.7348857 ,
-0.5810445 , -0.53310806],
[ 0.5055336 , -0.35796392, -0.90873146, ..., -0.78193223,
-0.04517283, 0.7075767 ],
...,
[-0.02022595, 0.01297146, -0.02664614, ..., 0.10258275,
-0.06243805, -0.07688553],
[-0.24571168, 0.18370496, 0.03886681, ..., -0.03063186,
-0.04676892, -0.10450852],
[ 0.03074093, -0.045911 , 0.07248624, ..., -0.05876327,
0.06366935, -0.01161662]], dtype=float32), 4), (array([[-0.28593373, 0.13395616, 0.48233178, ..., 0.508933 ,
-0.19197209, 0.3298264 ],
[-0.04904724, 0.34900156, 0.32834592, ..., -0.42706388,
-0.39813402, 0.14217453],
[ 0.13526852, 0.3745679 , 0.12265893, ..., 0.30098978,
0.15158501, 0.2164896 ],
...,
[ 0.21077481, 0.31915814, -0.32008937, ..., -0.48824465,
0.15033281, -0.55469203],
[-0.33247608, -0.05625575, 0.43155706, ..., 0.34942424,
0.04699863, 0.17167164],
[-0.32446697, 0.3883351 , -0.18434255, ..., -0.481489 ,
0.38606554, -0.31928998]], dtype=float32), array([[-0.21347475, 0.1707647 , 0.40377334, ..., 0.21519852,
-0.2119801 , 0.19800745],
[ 0.11622372, 0.4918646 , 0.4055998 , ..., -0.65552425,
-0.5142339 , 0.01027841],
[ 0.5708596 , 0.58031136, 0.5869403 , ..., 0.7259039 ,
0.10178877, 0.09589735],
...,
[-0.04035015, 0.08506791, 0.00343785, ..., 0.06102315,
0.07513183, -0.05011833],
[ 0.15481526, 0.14573523, -0.04516461, ..., -0.14253813,
-0.0463484 , -0.2259047 ],
[ 0.22652309, 0.2351952 , -0.03388155, ..., -0.04040325,
-0.17493977, -0.28690276]], dtype=float32), 5), (array([[ 0.01241684, 0.3848157 , -0.39109015, ..., 0.5478949 ,
-0.56430525, 0.6666779 ],
[ 0.14730261, -0.01858082, -0.3315102 , ..., 0.2657176 ,
-0.33017242, 0.21192974],
[ 0.33523917, -0.18178591, -0.3066008 , ..., 0.20106438,
-0.22945583, -0.03762559],
...,
[-0.2540486 , 0.18161409, -0.18843989, ..., -0.2343464 ,
-0.09964091, -0.06174641],
[-0.04365243, 0.5485354 , 0.20301312, ..., 0.46485335,
-0.5898763 , -0.32282397],
[-0.48032835, -0.44875026, 0.27545917, ..., -0.38980302,
-0.42448705, -0.28787732]], dtype=float32), array([[ 0.12385425, 0.34021133, -0.5947283 , ..., 0.5262172 ,
-0.50784826, 0.6099582 ],
[ 0.36343968, -0.00910554, -0.4031666 , ..., 0.29984346,
-0.51017505, 0.19207458],
[ 0.7371366 , -0.05033325, -0.65316683, ..., 0.6836462 ,
-0.82514614, 0.05400744],
...,
[ 0.00850144, 0.14323875, -0.17175199, ..., 0.02427784,
0.08420605, 0.08254603],
[-0.05369973, 0.1430688 , -0.18583198, ..., 0.22647178,
-0.23855914, 0.11469093],
[-0.16125312, 0.00549565, -0.1524472 , ..., 0.02182539,
-0.07359254, 0.13050811]], dtype=float32), 6), (array([[-0.04644716, 0.05705595, -0.24267559, ..., -0.21403529,
-0.06703684, 0.41887376],
[-0.29370484, -0.39780775, -0.3568661 , ..., -0.09881599,
0.07795003, 0.38119403],
[-0.3645612 , 0.10963462, -0.347853 , ..., -0.27757406,
-0.27060333, 0.3227043 ],
...,
[ 0.29615575, -0.4612898 , 0.05339388, ..., -0.08740558,
-0.25723857, -0.49486127],
[ 0.05623814, 0.4080824 , -0.24716274, ..., -0.09058192,
0.1756261 , -0.3089906 ],
[-0.5414306 , 0.4408762 , 0.42941254, ..., 0.15542302,
0.5396825 , -0.2709453 ]], dtype=float32), array([[-3.42568368e-01, 1.64543867e-01, 1.99411588e-04, ...,
-2.34075099e-01, 4.41714190e-02, 3.03957969e-01],
[-4.59636778e-01, -4.12208438e-01, -4.32887614e-01, ...,
-5.51471636e-02, 2.10178539e-01, 5.81714928e-01],
[-1.12530768e+00, 6.36167228e-01, -6.05338633e-01, ...,
-5.48392713e-01, -2.09195793e-01, 1.01132858e+00],
...,
[-1.30258754e-01, -4.16335762e-02, 1.45700663e-01, ...,
-3.17536071e-02, 3.02174967e-02, 1.10822864e-01],
[-1.39803439e-01, 3.05877943e-02, 1.05636232e-01, ...,
1.10526226e-01, 6.08176775e-02, -5.73274679e-02],
[-2.55209617e-02, 1.39828652e-01, -5.54599836e-02, ...,
1.40946835e-01, 1.95314527e-01, -7.23765837e-03]], dtype=float32), 7), (array([[-0.07237607, -0.06369619, -0.57799906, ..., -0.3011678 ,
-0.3869047 , -0.5708126 ],
[-0.00725966, -0.04352329, -0.14471681, ..., -0.47405225,
-0.11870398, 0.44799381],
[ 0.36965963, 0.19295754, -0.25880384, ..., -0.27418908,
-0.2637073 , -0.25275636],
...,
[ 0.06180732, 0.10883695, -0.20686714, ..., -0.4045689 ,
-0.10775824, -0.00597983],
[ 0.17705184, 0.2853461 , 0.38804924, ..., 0.00480051,
0.23195331, 0.5900061 ],
[ 0.33405438, -0.05846346, -0.49157378, ..., 0.13280089,
0.4277615 , -0.43489072]], dtype=float32), array([[-0.02955375, 0.03905029, -0.5314531 , ..., -0.3317717 ,
-0.32801694, -0.5832374 ],
[-0.15554425, 0.01338848, -0.30303237, ..., -0.5683995 ,
0.07281558, 0.39870417],
[ 0.57431275, 0.48248613, -0.36046496, ..., -0.740754 ,
-0.4101879 , -0.46861094],
...,
[-0.05291473, -0.00626671, -0.16552408, ..., -0.21676618,
0.00198667, 0.15837853],
[-0.20605578, 0.02732588, 0.05644984, ..., -0.14183098,
0.11027621, 0.00878341],
[ 0.01149986, 0.02788752, -0.07346534, ..., 0.02313588,
0.0365326 , 0.0427165 ]], dtype=float32), 8), (array([[ 0.32586667, 0.6187046 , 0.3162743 , ..., 0.42238498,
-0.08919488, -0.23198491],
[ 0.1516603 , 0.04449864, 0.10896807, ..., -0.49212578,
0.14955266, -0.04248938],
[-0.23296818, -0.16775374, 0.41552317, ..., -0.27849862,
0.08736038, 0.36777073],
...,
[ 0.17306753, 0.4053642 , -0.063707 , ..., 0.09530412,
-0.46045092, 0.42887986],
[-0.63025045, 0.3146556 , 0.148895 , ..., 0.48645812,
0.27349007, -0.0574788 ],
[-0.00731906, -0.03560491, -0.3711313 , ..., 0.2597622 ,
0.44751585, 0.30753264]], dtype=float32), array([[ 0.48764184, 0.54429394, 0.461604 , ..., 0.5304342 ,
0.20013283, -0.23899378],
[ 0.15820505, -0.131124 , -0.02082682, ..., -0.58288604,
0.04917484, -0.06569459],
[-0.74465793, -0.29858342, 1.2702694 , ..., -0.80466145,
0.45967343, 0.913658 ],
...,
[-0.02027037, -0.06839398, -0.14130257, ..., -0.14568064,
-0.02432712, 0.02081295],
[-0.03228899, -0.05201127, 0.1773543 , ..., 0.05981961,
-0.13131005, 0.00199337],
[-0.14305267, 0.02065172, 0.01255422, ..., 0.02369817,
0.0101981 , 0.13320476]], dtype=float32), 9)]
Any idea what causes this weird behaviour? Could it be something related to byte order in NumPy ndarrays?
My fault. I forgot about Spark's lazy evaluation: map() is a transformation, not an action like take() and collect(), and transformations are not actually evaluated until an action is called.
This, combined with the fact that my mapping function is not deterministic (it does not return the same value each time it runs), leads to the following behaviour. Calling take(10) evaluates the map() for the first time and returns one set of values. Calling collect() evaluates the map() a second time and returns completely different values. Calling first() triggers yet another evaluation.