I run a PySpark script that performs a map operation on an RDD. The result of this map is a new RDD containing tuples, i.e. one tuple for each partition in the original RDD (there are 10 partitions). Each tuple contains two NumPy ndarrays and the partition ID.
The problem is that take(10) and collect() return different values! Even if I perform first() on the resulting RDD, the tuple is totally different from the first tuple returned by take() or collect().
take(10) result is:
[(array([[ 0.19138815, -0.26613894, 0.0148395 , ..., -0.6887879 ,
0.01775263, 0.29900053],
[ 0.46834013, -0.41492677, -0.3986189 , ..., -0.09638319,
-0.27234066, 0.41824088],
[ 0.2235235 , -0.27003226, 0.05322047, ..., -0.3045229 ,
-0.30364496, 0.21981548],
...,
[ 0.40709212, -0.49947056, 0.36821032, ..., -0.27359277,
-0.1552616 , -0.10506155],
[ 0.42315334, -0.6249347 , -0.38093382, ..., -0.52247494,
0.12167282, 0.53337336],
[-0.12733063, -0.27234274, -0.05421005, ..., 0.3884521 ,
0.19977048, 0.11347781]], dtype=float32), array([[ 0.04737013, -0.13501787, -0.04363494, ..., -0.53189725,
-0.1747366 , 0.09700331],
[ 0.59339726, -0.622871 , -0.7482438 , ..., 0.01563211,
-0.24306081, 0.5259429 ],
[ 0.199726 , -0.29667157, 0.03788655, ..., -0.72914207,
-0.61997485, 0.4880442 ],
...,
[ 0.10908518, 0.15787347, 0.13822514, ..., -0.015803 ,
-0.05935803, 0.11415796],
[ 0.02415699, 0.21212797, -0.01504083, ..., -0.0880443 ,
0.03949256, -0.17959005],
[-0.08058456, -0.00894655, -0.10706384, ..., 0.00954069,
-0.18720922, -0.05665499]], dtype=float32), 0), (array([[ 0.53325355, -0.0129254 , 0.10924862, ..., 0.2797827 ,
0.44130138, -0.29074535],
[ 0.29487053, 0.20387554, 0.00834447, ..., 0.3034479 ,
-0.34347925, 0.48914096],
[-0.36053488, 0.2551153 , 0.23102154, ..., -0.08557958,
0.2305064 , -0.11637823],
...,
[ 0.50112355, -0.5487336 , 0.1381122 , ..., -0.17219128,
0.5784589 , -0.39060545],
[ 0.47822553, -0.21500733, -0.02590418, ..., 0.45222896,
-0.29980502, -0.4379743 ],
[-0.27317327, 0.47888 , 0.13328783, ..., 0.45453754,
-0.03382564, 0.28364402]], dtype=float32), array([[ 0.45505902, 0.20775506, 0.38936007, ..., 0.31039652,
0.0612184 , -0.34329525],
[ 0.44160137, 0.45554656, 0.24406506, ..., 0.20645882,
-0.55068386, 0.61949503],
[-0.9611948 , 0.48338398, 0.8936671 , ..., -0.07192033,
-0.04691654, -0.0583482 ],
...,
[ 0.27866694, 0.11791337, 0.00603435, ..., -0.0984261 ,
0.05514587, -0.16367936],
[ 0.24686542, -0.01605012, 0.15803055, ..., 0.14647359,
0.00465332, -0.21551773],
[ 0.01786835, 0.08164094, 0.0458132 , ..., 0.09466646,
-0.07061186, 0.04650302]], dtype=float32), 1), (array([[-0.30373037, 0.4013883 , 0.5544747 , ..., -0.24839583,
-0.434404 , -0.5419062 ],
[ 0.3205092 , -0.21219605, 0.23547144, ..., -0.4373149 ,
0.30616343, -0.45202586],
[-0.25409338, -0.14463192, 0.30881113, ..., -0.29998812,
-0.24947752, -0.18218543],
...,
[-0.33871105, 0.20521186, 0.38351473, ..., 0.33596635,
0.370852 , -0.46504658],
[-0.23807049, 0.1317612 , 0.07848132, ..., -0.32858378,
-0.541284 , -0.4595052 ],
[-0.17040105, 0.48929718, -0.39259502, ..., -0.15026243,
-0.19829535, 0.18581793]], dtype=float32), array([[-0.5295805 , 0.64539313, 0.47219488, ..., -0.4475045 ,
-0.31032106, -0.60634965],
[ 0.28920087, -0.2322041 , 0.313627 , ..., -0.71224356,
0.1539378 , -0.39970958],
[-0.39529595, -0.50494266, 0.61120296, ..., -0.94842255,
-0.7700451 , 0.03852249],
...,
[ 0.082367 , -0.09744105, 0.00574634, ..., 0.04329732,
0.08105459, -0.20032766],
[ 0.04620253, 0.18625231, -0.04047911, ..., -0.1547844 ,
-0.01560262, 0.00372486],
[ 0.0602592 , 0.078048 , 0.04372916, ..., -0.10393928,
-0.27185628, -0.05753115]], dtype=float32), 2), (array([[-0.10130739, -0.5432671 , 0.14230369, ..., -0.20037425,
-0.50981474, -0.39152429],
[-0.10439714, -0.19250502, -0.12469167, ..., 0.50656915,
0.41846293, -0.12511848],
[-0.12075678, 0.13746399, -0.10762265, ..., -0.33095708,
0.38831544, -0.3573719 ],
...,
[-0.33046186, 0.1675668 , -0.22495636, ..., 0.39853546,
-0.21838626, -0.44713587],
[ 0.1753454 , -0.16229314, 0.24015644, ..., -0.09867255,
-0.6676028 , -0.03644819],
[ 0.33300576, -0.28793842, 0.45033735, ..., 0.26602328,
0.18528135, -0.37736982]], dtype=float32), array([[-0.4201235 , -0.42547926, 0.3823224 , ..., -0.14358595,
-0.33159116, -0.47784004],
[-0.29869375, -0.31769726, 0.00777127, ..., 0.5764582 ,
0.160321 , -0.14145109],
[ 0.08045941, 0.31604633, -0.22956751, ..., -0.551948 ,
0.8479948 , -1.065378 ],
...,
[ 0.02872542, -0.23239408, 0.10584452, ..., -0.19653268,
0.01299276, 0.01154038],
[-0.11145593, -0.04056952, 0.17250092, ..., 0.11815882,
-0.09970374, -0.14516412],
[ 0.03227303, 0.08426446, -0.04518193, ..., -0.04804583,
0.11866879, 0.07962584]], dtype=float32), 3), (array([[-0.3442908 , -0.47924808, -0.17909145, ..., 0.57778424,
-0.20964314, -0.38468337],
[-0.1493048 , 0.22245084, -0.19882523, ..., 0.38685095,
-0.5174358 , -0.4666229 ],
[-0.36909524, -0.32142702, -0.183554 , ..., 0.16354944,
-0.31103647, -0.12249314],
...,
[-0.45103338, -0.04599327, 0.3811725 , ..., -0.28528866,
-0.47804165, -0.08458076],
[ 0.2814856 , -0.07282971, -0.6203528 , ..., 0.1307708 ,
0.01159024, -0.28599057],
[-0.20529339, 0.39615756, -0.12310734, ..., 0.21978702,
0.09362441, 0.282171 ]], dtype=float32), array([[-3.9894256e-01, -2.3258871e-01, -3.5374862e-01, ...,
5.5857885e-01, 1.3151944e-02, -3.4470174e-01],
[ 3.7879427e-03, 2.3241083e-01, -1.8348652e-01, ...,
2.8142187e-01, -5.6694925e-01, -7.0757174e-01],
[-1.0855644e+00, -1.0010500e+00, -5.0151312e-01, ...,
1.0360012e+00, -8.4597129e-01, 3.6014456e-01],
...,
[ 9.6085869e-02, 4.7818022e-03, 9.2197888e-02, ...,
5.7999889e-04, -4.3681998e-02, -1.0792557e-01],
[-2.1461031e-01, -7.9450890e-02, -3.6970940e-03, ...,
1.2008668e-01, 8.1481531e-02, -7.4978117e-03],
[ 2.2965204e-02, -7.5380646e-02, -1.5798187e-01, ...,
2.3251076e-01, 2.0147276e-01, -5.9070695e-02]], dtype=float32), 4), (array([[ 0.34773657, 0.1271629 , 0.35084245, ..., -0.4163662 ,
-0.2032509 , 0.16763268],
[ 0.25300634, -0.24963312, 0.03022493, ..., -0.48368084,
-0.5438243 , -0.2523892 ],
[ 0.08074404, 0.34778172, -0.19389851, ..., 0.00394286,
-0.03249731, 0.21850894],
...,
[-0.28799132, 0.46877176, 0.13380927, ..., 0.21133232,
-0.20794198, 0.04732013],
[ 0.533321 , 0.08838256, 0.52776694, ..., -0.3776623 ,
0.3234356 , 0.49971387],
[ 0.01111342, 0.2225688 , 0.5407952 , ..., -0.5621694 ,
-0.48505995, 0.45497927]], dtype=float32), array([[ 0.31677225, 0.3403609 , 0.43394235, ..., -0.22031932,
-0.19851643, 0.07987656],
[ 0.23734477, -0.08880972, 0.29560354, ..., -0.29227307,
-0.6075761 , 0.0549405 ],
[ 0.51728237, 0.92628396, -0.40312177, ..., 0.47099453,
0.1400124 , 0.5190122 ],
...,
[-0.01423888, -0.00644734, -0.01985722, ..., 0.00387737,
0.08140853, -0.04754239],
[-0.20655577, 0.07879162, 0.16483389, ..., 0.03559271,
-0.08346304, 0.02901335],
[ 0.12661955, -0.14012276, 0.13914892, ..., 0.09200166,
-0.01300973, 0.09209644]], dtype=float32), 5), (array([[ 0.6097177 , 0.48289144, 0.107826 , ..., 0.14263372,
0.43684664, 0.37251818],
[-0.47600058, -0.32621023, -0.55021507, ..., -0.48766413,
0.19021659, -0.25467217],
[-0.26290026, 0.19371034, -0.3084416 , ..., -0.38468444,
-0.24624074, -0.1248633 ],
...,
[ 0.29716069, -0.17764345, 0.06396675, ..., -0.51036185,
-0.16205388, -0.11298727],
[ 0.11154262, 0.17699954, -0.55426383, ..., 0.36997363,
0.10982917, -0.16837421],
[-0.0491058 , 0.13139851, 0.38043728, ..., -0.20505017,
0.15700234, 0.5629593 ]], dtype=float32), array([[ 5.4628938e-01, 4.5117161e-01, 3.2670802e-01, ...,
1.9686517e-01, 3.2789686e-01, 4.5343029e-01],
[-6.4986098e-01, -2.4272519e-01, -7.6540470e-01, ...,
-7.3495942e-01, 2.8457634e-02, -1.3791913e-01],
[-7.1882963e-01, 6.7012483e-01, -6.4978844e-01, ...,
-6.2161821e-01, -6.2733275e-01, -3.5520321e-01],
...,
[-9.8777540e-02, -3.5368185e-02, 3.7035108e-02, ...,
3.2906037e-02, -3.7386462e-02, -7.3537357e-02],
[ 7.1388625e-02, 1.7618576e-01, -3.5727687e-02, ...,
-1.8145926e-01, -8.1750348e-02, 1.6553156e-02],
[-4.8265442e-02, -1.0864433e-01, -7.2953522e-02, ...,
5.3536858e-02, 1.4538332e-04, 1.6074999e-01]], dtype=float32), 6), (array([[-0.60614747, 0.12892431, -0.3011826 , ..., 0.25248086,
-0.52810246, 0.14257856],
[-0.42474708, -0.37544054, -0.3886031 , ..., -0.35798463,
-0.18877436, 0.3048291 ],
[-0.17884058, -0.37839454, 0.3589297 , ..., 0.05497296,
-0.06037642, -0.4278129 ],
...,
[ 0.1893682 , 0.22125317, -0.20827802, ..., -0.29234493,
-0.1302274 , 0.02801383],
[ 0.1747711 , -0.0879383 , -0.395539 , ..., -0.38479805,
-0.61469847, 0.00207329],
[ 0.3819209 , 0.49023125, -0.42264247, ..., 0.1528128 ,
0.45578462, 0.4668125 ]], dtype=float32), array([[-5.28453410e-01, 1.38694681e-02, -2.60350943e-01, ...,
2.34538317e-01, -4.43873584e-01, 1.47595644e-01],
[-5.25614679e-01, -4.94065553e-01, -4.03498739e-01, ...,
-4.45649236e-01, -1.77955553e-01, 3.10622215e-01],
[ 1.01587474e-01, -8.57380509e-01, 8.03865731e-01, ...,
2.06392854e-01, -1.75320432e-01, -7.59775221e-01],
...,
[ 8.77973512e-02, -2.19333827e-01, -1.42306268e-01, ...,
-1.03031479e-01, 6.18839522e-06, 1.56334475e-01],
[-1.23811606e-02, -3.57912146e-02, -7.11395666e-02, ...,
4.88635562e-02, 1.20033674e-01, -2.04370469e-02],
[ 1.00751802e-01, -4.78806317e-01, 1.95426181e-01, ...,
-5.01195967e-01, -3.37407261e-01, 3.57391506e-01]], dtype=float32), 7), (array([[ 0.49053288, -0.04683368, -0.0879433 , ..., 0.6166893 ,
-0.5472508 , 0.5924473 ],
[-0.40197268, 0.21878959, 0.47748646, ..., -0.27519724,
0.3854015 , -0.04976773],
[ 0.24380569, -0.04194092, -0.22590604, ..., 0.35376453,
0.2546404 , -0.20142618],
...,
[-0.5142191 , -0.2877738 , -0.47166097, ..., 0.48306477,
0.26082426, -0.4445646 ],
[-0.25510445, 0.00508366, -0.5078671 , ..., -0.27604827,
-0.08479042, 0.04850767],
[-0.37874135, 0.49107817, 0.11259978, ..., 0.3188926 ,
0.15944068, 0.0829725 ]], dtype=float32), array([[ 0.5134689 , -0.18948251, -0.05751067, ..., 0.48557195,
-0.3173096 , 0.372445 ],
[-0.43809155, -0.08225048, 0.7466321 , ..., -0.5873075 ,
0.7402181 , -0.09714872],
[ 0.68586934, 0.1767921 , 0.07150997, ..., 0.8427489 ,
0.45624557, -0.32352042],
...,
[ 0.02128473, -0.00579748, 0.12406835, ..., 0.08425785,
0.21446604, 0.09736192],
[ 0.04778935, -0.12865652, 0.00840133, ..., -0.00293668,
0.06652898, -0.03904634],
[ 0.02595548, 0.11789754, -0.02631662, ..., 0.03307972,
0.24783 , -0.12637296]], dtype=float32), 8), (array([[-0.09508003, 0.5282958 , 0.04243381, ..., -0.0648976 ,
-0.56710106, -0.06858341],
[-0.23444827, -0.17710732, -0.29838824, ..., -0.03992304,
-0.51030684, -0.34101528],
[ 0.00919831, 0.00875767, -0.09158213, ..., 0.26422626,
-0.19114144, 0.2717857 ],
...,
[-0.3377122 , 0.49312168, -0.50667596, ..., -0.33557674,
-0.3865259 , 0.4990052 ],
[ 0.17765625, -0.06699899, 0.29469523, ..., 0.612583 ,
0.13147196, -0.27174017],
[-0.56627136, -0.25801656, -0.28928643, ..., 0.18859185,
-0.46310693, 0.23317206]], dtype=float32), array([[ 0.22075664, 0.41444212, -0.4506032 , ..., -0.16359589,
-0.569371 , -0.13674185],
[-0.19642216, -0.36554584, -0.45241717, ..., -0.06579936,
-0.656618 , -0.11742882],
[-0.04687649, 0.10208729, -0.060711 , ..., 0.78385997,
-0.17697819, 0.75614196],
...,
[ 0.2150026 , 0.15821525, -0.16107246, ..., -0.20459318,
-0.20312549, 0.11470713],
[ 0.06135871, 0.06140292, -0.11083869, ..., 0.14165913,
0.05527819, 0.0520999 ],
[ 0.13505958, 0.06166949, -0.11568853, ..., -0.04749477,
0.02072733, -0.05888995]], dtype=float32), 9)]
collect() result is:
[(array([[ 0.18516347, -0.33975708, -0.46829244, ..., 0.498327 ,
-0.1628269 , 0.5171599 ],
[ 0.24843855, 0.43924475, 0.43121427, ..., -0.3605212 ,
-0.2543247 , -0.35761902],
[ 0.03349265, 0.28567392, -0.3129074 , ..., 0.30228034,
0.33539015, 0.28145155],
...,
[ 0.39538023, -0.11668223, 0.23590142, ..., -0.39222914,
-0.34792763, 0.43729994],
[-0.37299404, -0.40583754, -0.41405225, ..., 0.3708834 ,
0.6067088 , 0.5815965 ],
[-0.5297639 , 0.09037948, 0.06255247, ..., 0.55813074,
0.2599809 , 0.2930913 ]], dtype=float32), array([[-0.05441456, -0.08513332, -0.47554415, ..., 0.46570715,
0.11365455, 0.641596 ],
[ 0.25915185, 0.6784206 , 0.5428535 , ..., -0.3223893 ,
-0.17784661, -0.3021973 ],
[ 0.4007858 , 0.48166505, -0.7551351 , ..., 0.5893394 ,
0.5379706 , 0.5853663 ],
...,
[ 0.03785443, -0.02449032, 0.07482295, ..., 0.14570121,
0.02578176, 0.11021709],
[ 0.02157725, 0.20236807, -0.25889152, ..., -0.2123813 ,
0.11124409, 0.05835798],
[-0.2482176 , 0.1100187 , -0.16054511, ..., 0.15483692,
-0.01258843, -0.000899 ]], dtype=float32), 0), (array([[-0.06023056, 0.34326535, 0.01615176, ..., 0.50180113,
0.35740197, -0.3607464 ],
[-0.37374282, -0.05733229, -0.10494906, ..., 0.10067802,
-0.30181995, 0.19373518],
[-0.21093842, -0.35539758, 0.2722222 , ..., -0.13212524,
0.15457118, 0.29343936],
...,
[-0.3844776 , 0.29577827, 0.23207994, ..., -0.2748728 ,
0.05118364, -0.43278512],
[ 0.18988602, 0.15946351, -0.37208527, ..., 0.18980268,
0.26914784, 0.57002 ],
[-0.19685334, -0.00215623, -0.50676346, ..., -0.25601804,
0.43306062, -0.45977998]], dtype=float32), array([[ 0.07212897, 0.37888303, 0.14216022, ..., 0.3568563 ,
0.31809187, -0.30161127],
[-0.14504915, 0.09403719, -0.24099208, ..., 0.1194509 ,
-0.571604 , 0.3073006 ],
[-0.34506813, -0.5373435 , 0.39612344, ..., -0.17277275,
-0.15978633, 0.9480076 ],
...,
[ 0.33411705, 0.3616483 , -0.00284358, ..., -0.1570069 ,
-0.10693 , 0.11339971],
[ 0.03463184, 0.02923653, -0.07571009, ..., -0.01076985,
-0.24661644, 0.07400434],
[ 0.03500444, 0.02041529, 0.11345199, ..., -0.21085429,
-0.11910766, -0.16162656]], dtype=float32), 1), (array([[ 0.40148696, 0.44860718, -0.13891219, ..., 0.38478658,
0.0354558 , 0.5507143 ],
[ 0.35727167, 0.33523828, -0.40109852, ..., -0.12698223,
-0.48993534, 0.4485451 ],
[-0.37534824, -0.07898732, -0.28266475, ..., 0.0923453 ,
0.20398574, 0.46967202],
...,
[-0.31968838, 0.47479796, -0.49929148, ..., -0.23865293,
-0.24262336, -0.06511039],
[ 0.2098573 , -0.5782443 , 0.0044039 , ..., -0.13356705,
-0.5997722 , 0.24789433],
[ 0.18989192, -0.41790476, -0.5493083 , ..., -0.04386537,
0.14099114, -0.3851897 ]], dtype=float32), array([[ 0.35465854, 0.5298055 , 0.09639163, ..., 0.42402148,
0.1006598 , 0.42896628],
[ 0.49630994, 0.53968257, -0.45613742, ..., 0.11758496,
-0.82799774, 0.47129843],
[-0.11091773, -0.126068 , -0.94792765, ..., -0.39254257,
0.49629924, 0.90804875],
...,
[ 0.09986763, 0.18694244, 0.04551417, ..., -0.00185746,
0.04787954, 0.14079888],
[ 0.01976901, 0.01671817, 0.02434383, ..., -0.05640491,
0.03537085, -0.08094196],
[-0.12894829, 0.15826981, -0.09516723, ..., -0.11121278,
-0.17831786, -0.00805143]], dtype=float32), 2), (array([[-0.16927287, -0.3801918 , -0.32327962, ..., -0.51245123,
0.41986853, 0.18242987],
[-0.41638187, -0.06312063, -0.40284333, ..., 0.26918623,
-0.4305522 , -0.4801858 ],
[ 0.00497885, -0.22712015, -0.35257223, ..., 0.02938372,
0.32673585, 0.176891 ],
...,
[ 0.22629777, 0.39141867, 0.3272797 , ..., -0.45520803,
0.17408061, -0.27852598],
[-0.24445221, -0.35762975, -0.39768136, ..., 0.26196685,
0.17221238, 0.22423406],
[ 0.38731757, -0.45889175, 0.3848555 , ..., 0.469341 ,
0.2884723 , 0.4584588 ]], dtype=float32), array([[-0.38282454, -0.33555806, -0.38107285, ..., -0.4294134 ,
0.40941086, 0.19802842],
[-0.5812499 , -0.08969598, -0.5220119 , ..., 0.35569692,
-0.436198 , -0.78949076],
[-0.6598932 , -0.61103916, -0.8173852 , ..., 0.39758506,
0.77731544, 0.38008815],
...,
[ 0.16403008, -0.06079637, -0.02834259, ..., -0.10361861,
0.10504336, 0.21878755],
[-0.04492131, -0.17708012, -0.2025333 , ..., 0.17902693,
-0.15750957, -0.23726523],
[-0.07494884, -0.06165536, -0.10450385, ..., 0.17355536,
0.01258762, -0.05638846]], dtype=float32), 3), (array([[ 0.00608851, -0.5132972 , -0.36956814, ..., -0.5321781 ,
-0.26002124, 0.18518737],
[-0.21294856, -0.30260456, 0.11435414, ..., -0.52438897,
-0.46490332, -0.43702644],
[ 0.32582685, -0.25917143, -0.34966806, ..., -0.29239386,
-0.09908075, 0.32888958],
...,
[-0.25709772, 0.32678795, 0.00447032, ..., -0.04280795,
0.01051257, 0.36575 ],
[-0.07194624, -0.245111 , 0.5667958 , ..., 0.41409835,
-0.5799336 , -0.4858021 ],
[-0.11811656, 0.52633476, 0.40527648, ..., 0.40293694,
-0.4926284 , -0.14098077]], dtype=float32), array([[-0.11553568, -0.3299477 , -0.50079507, ..., -0.44186658,
0.01149882, 0.01778618],
[-0.3005174 , -0.26895422, 0.08186761, ..., -0.7348857 ,
-0.5810445 , -0.53310806],
[ 0.5055336 , -0.35796392, -0.90873146, ..., -0.78193223,
-0.04517283, 0.7075767 ],
...,
[-0.02022595, 0.01297146, -0.02664614, ..., 0.10258275,
-0.06243805, -0.07688553],
[-0.24571168, 0.18370496, 0.03886681, ..., -0.03063186,
-0.04676892, -0.10450852],
[ 0.03074093, -0.045911 , 0.07248624, ..., -0.05876327,
0.06366935, -0.01161662]], dtype=float32), 4), (array([[-0.28593373, 0.13395616, 0.48233178, ..., 0.508933 ,
-0.19197209, 0.3298264 ],
[-0.04904724, 0.34900156, 0.32834592, ..., -0.42706388,
-0.39813402, 0.14217453],
[ 0.13526852, 0.3745679 , 0.12265893, ..., 0.30098978,
0.15158501, 0.2164896 ],
...,
[ 0.21077481, 0.31915814, -0.32008937, ..., -0.48824465,
0.15033281, -0.55469203],
[-0.33247608, -0.05625575, 0.43155706, ..., 0.34942424,
0.04699863, 0.17167164],
[-0.32446697, 0.3883351 , -0.18434255, ..., -0.481489 ,
0.38606554, -0.31928998]], dtype=float32), array([[-0.21347475, 0.1707647 , 0.40377334, ..., 0.21519852,
-0.2119801 , 0.19800745],
[ 0.11622372, 0.4918646 , 0.4055998 , ..., -0.65552425,
-0.5142339 , 0.01027841],
[ 0.5708596 , 0.58031136, 0.5869403 , ..., 0.7259039 ,
0.10178877, 0.09589735],
...,
[-0.04035015, 0.08506791, 0.00343785, ..., 0.06102315,
0.07513183, -0.05011833],
[ 0.15481526, 0.14573523, -0.04516461, ..., -0.14253813,
-0.0463484 , -0.2259047 ],
[ 0.22652309, 0.2351952 , -0.03388155, ..., -0.04040325,
-0.17493977, -0.28690276]], dtype=float32), 5), (array([[ 0.01241684, 0.3848157 , -0.39109015, ..., 0.5478949 ,
-0.56430525, 0.6666779 ],
[ 0.14730261, -0.01858082, -0.3315102 , ..., 0.2657176 ,
-0.33017242, 0.21192974],
[ 0.33523917, -0.18178591, -0.3066008 , ..., 0.20106438,
-0.22945583, -0.03762559],
...,
[-0.2540486 , 0.18161409, -0.18843989, ..., -0.2343464 ,
-0.09964091, -0.06174641],
[-0.04365243, 0.5485354 , 0.20301312, ..., 0.46485335,
-0.5898763 , -0.32282397],
[-0.48032835, -0.44875026, 0.27545917, ..., -0.38980302,
-0.42448705, -0.28787732]], dtype=float32), array([[ 0.12385425, 0.34021133, -0.5947283 , ..., 0.5262172 ,
-0.50784826, 0.6099582 ],
[ 0.36343968, -0.00910554, -0.4031666 , ..., 0.29984346,
-0.51017505, 0.19207458],
[ 0.7371366 , -0.05033325, -0.65316683, ..., 0.6836462 ,
-0.82514614, 0.05400744],
...,
[ 0.00850144, 0.14323875, -0.17175199, ..., 0.02427784,
0.08420605, 0.08254603],
[-0.05369973, 0.1430688 , -0.18583198, ..., 0.22647178,
-0.23855914, 0.11469093],
[-0.16125312, 0.00549565, -0.1524472 , ..., 0.02182539,
-0.07359254, 0.13050811]], dtype=float32), 6), (array([[-0.04644716, 0.05705595, -0.24267559, ..., -0.21403529,
-0.06703684, 0.41887376],
[-0.29370484, -0.39780775, -0.3568661 , ..., -0.09881599,
0.07795003, 0.38119403],
[-0.3645612 , 0.10963462, -0.347853 , ..., -0.27757406,
-0.27060333, 0.3227043 ],
...,
[ 0.29615575, -0.4612898 , 0.05339388, ..., -0.08740558,
-0.25723857, -0.49486127],
[ 0.05623814, 0.4080824 , -0.24716274, ..., -0.09058192,
0.1756261 , -0.3089906 ],
[-0.5414306 , 0.4408762 , 0.42941254, ..., 0.15542302,
0.5396825 , -0.2709453 ]], dtype=float32), array([[-3.42568368e-01, 1.64543867e-01, 1.99411588e-04, ...,
-2.34075099e-01, 4.41714190e-02, 3.03957969e-01],
[-4.59636778e-01, -4.12208438e-01, -4.32887614e-01, ...,
-5.51471636e-02, 2.10178539e-01, 5.81714928e-01],
[-1.12530768e+00, 6.36167228e-01, -6.05338633e-01, ...,
-5.48392713e-01, -2.09195793e-01, 1.01132858e+00],
...,
[-1.30258754e-01, -4.16335762e-02, 1.45700663e-01, ...,
-3.17536071e-02, 3.02174967e-02, 1.10822864e-01],
[-1.39803439e-01, 3.05877943e-02, 1.05636232e-01, ...,
1.10526226e-01, 6.08176775e-02, -5.73274679e-02],
[-2.55209617e-02, 1.39828652e-01, -5.54599836e-02, ...,
1.40946835e-01, 1.95314527e-01, -7.23765837e-03]], dtype=float32), 7), (array([[-0.07237607, -0.06369619, -0.57799906, ..., -0.3011678 ,
-0.3869047 , -0.5708126 ],
[-0.00725966, -0.04352329, -0.14471681, ..., -0.47405225,
-0.11870398, 0.44799381],
[ 0.36965963, 0.19295754, -0.25880384, ..., -0.27418908,
-0.2637073 , -0.25275636],
...,
[ 0.06180732, 0.10883695, -0.20686714, ..., -0.4045689 ,
-0.10775824, -0.00597983],
[ 0.17705184, 0.2853461 , 0.38804924, ..., 0.00480051,
0.23195331, 0.5900061 ],
[ 0.33405438, -0.05846346, -0.49157378, ..., 0.13280089,
0.4277615 , -0.43489072]], dtype=float32), array([[-0.02955375, 0.03905029, -0.5314531 , ..., -0.3317717 ,
-0.32801694, -0.5832374 ],
[-0.15554425, 0.01338848, -0.30303237, ..., -0.5683995 ,
0.07281558, 0.39870417],
[ 0.57431275, 0.48248613, -0.36046496, ..., -0.740754 ,
-0.4101879 , -0.46861094],
...,
[-0.05291473, -0.00626671, -0.16552408, ..., -0.21676618,
0.00198667, 0.15837853],
[-0.20605578, 0.02732588, 0.05644984, ..., -0.14183098,
0.11027621, 0.00878341],
[ 0.01149986, 0.02788752, -0.07346534, ..., 0.02313588,
0.0365326 , 0.0427165 ]], dtype=float32), 8), (array([[ 0.32586667, 0.6187046 , 0.3162743 , ..., 0.42238498,
-0.08919488, -0.23198491],
[ 0.1516603 , 0.04449864, 0.10896807, ..., -0.49212578,
0.14955266, -0.04248938],
[-0.23296818, -0.16775374, 0.41552317, ..., -0.27849862,
0.08736038, 0.36777073],
...,
[ 0.17306753, 0.4053642 , -0.063707 , ..., 0.09530412,
-0.46045092, 0.42887986],
[-0.63025045, 0.3146556 , 0.148895 , ..., 0.48645812,
0.27349007, -0.0574788 ],
[-0.00731906, -0.03560491, -0.3711313 , ..., 0.2597622 ,
0.44751585, 0.30753264]], dtype=float32), array([[ 0.48764184, 0.54429394, 0.461604 , ..., 0.5304342 ,
0.20013283, -0.23899378],
[ 0.15820505, -0.131124 , -0.02082682, ..., -0.58288604,
0.04917484, -0.06569459],
[-0.74465793, -0.29858342, 1.2702694 , ..., -0.80466145,
0.45967343, 0.913658 ],
...,
[-0.02027037, -0.06839398, -0.14130257, ..., -0.14568064,
-0.02432712, 0.02081295],
[-0.03228899, -0.05201127, 0.1773543 , ..., 0.05981961,
-0.13131005, 0.00199337],
[-0.14305267, 0.02065172, 0.01255422, ..., 0.02369817,
0.0101981 , 0.13320476]], dtype=float32), 9)]
Any idea what causes this weird behaviour? Maybe something related to byte order in NumPy ndarrays?
My fault. I forgot about Spark's lazy evaluation (map() is a transformation, not an action like take() and collect(); transformations are not actually evaluated until an action is called).
This, combined with the fact that my mapping function is not deterministic (i.e. it does not return the same values across executions), leads to the following behaviour: calling take(10) evaluates the map() a first time (returning one set of values), and calling collect() evaluates it a second time (returning a totally different set of values).
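For reference, here is a minimal sketch of how to avoid the double evaluation (rdd and my_map_fn below are just stand-ins for the actual RDD and mapping function): persisting the mapped RDD and materializing it once makes first(), take() and collect() all see the same values.

# Sketch only: my_map_fn stands in for the actual (non-deterministic) mapping function.
mapped = rdd.mapPartitionsWithIndex(my_map_fn)  # a transformation, so still lazy

mapped.persist()   # or .cache(): mark the result for reuse
mapped.count()     # any action forces every partition to be computed exactly once

first_tuple = mapped.first()    # these three are now consistent with each other,
ten_tuples = mapped.take(10)    # because the cached partitions are reused
all_tuples = mapped.collect()   # instead of re-running the map

Alternatively, making the mapping function deterministic (e.g. by fixing any random seeds per partition) removes the discrepancy without caching.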
I need to quickly process a huge two-dimensional array and have already pre-marked the required data.
array([[ 0., 1., 2., 3., 4., 5. , 6. , 7.],
[ 6., 7., 8., 9., 10., 4.2, 4.3, 11.],
[ 12., 13., 14., 15., 16., 4.2, 4.3, 17.],
[ 18., 19., 20., 21., 22., 4.2, 4.3, 23.]])
array([[False, True, True, True, False, True, True , False],
[False, False, False, True, True, True, True , False],
[False, False, True, True, False, False, False, False],
[False, True, True, False, False, False, True , True ]])
I want to cumulatively sum the marked data within each row of the array. But np.cumsum on its own can't do this; I need solutions or good ideas, thanks.
Expected output:
array([[ 0. ,  1. ,  3. ,  6. ,  0. ,  5. , 11. ,  0. ],
       [ 0. ,  0. ,  0. ,  9. , 19. , 23.2, 27.5,  0. ],
       [ 0. ,  0. , 14. , 29. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. , 19. , 39. ,  0. ,  0. ,  0. ,  4.3, 27.3]])
The difficulty is that each masked fragment must not carry over the cumulative result of the previous fragment; a plain-Python sketch of the intended behaviour follows.
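For reference, this loop-based sketch reproduces the expected output above, but it is far too slow for a huge array, so a vectorized equivalent is what I'm after:

import numpy as np

def masked_cumsum_loop(a, m):
    # Row-wise cumulative sum over the True entries of m, restarting at every False.
    out = np.zeros_like(a, dtype=float)
    for i in range(a.shape[0]):
        running = 0.0
        for j in range(a.shape[1]):
            if m[i, j]:
                running += a[i, j]   # continue the current True segment
                out[i, j] = running
            else:
                running = 0.0        # a False entry ends the segment; output stays 0
    return out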
def mask_to_size(self, axis=-1):
    if self.ndim == 2:
        if axis == 0:
            mask = np.zeros((self.shape[0] + 1, self.shape[1]), dtype=bool)
            mask[:-1] = self
            mask[0] = False
            mask = mask.ravel('F')
        else:
            mask = np.zeros((self.shape[0], self.shape[1] + 1), dtype=bool)
            mask[:, 0:-1] = self
            mask[:, 0] = False
            mask = mask.ravel('C')
    else:
        mask = np.zeros((self.shape[0] + 1), dtype=bool)
        mask[:-1] = self
        mask[0] = False
    return np.diff(np.nonzero(mask[1:] != mask[:-1])[0])[::2].astype(int)

# https://stackoverflow.com/a/49179628/ by #Divakar
def intervaled_cumsum(ar, sizes):
    out = ar.copy()
    arc = ar.cumsum()
    idx = sizes.cumsum()
    out[idx[0]] = ar[idx[0]] - arc[idx[0] - 1]
    out[idx[1:-1]] = ar[idx[1:-1]] - np.diff(arc[idx[:-1] - 1])
    return out.cumsum()

def cumsum_masked(self, mask, axis=-1):
    sizes = mask_to_size(mask, axis)
    out = np.zeros(self.size)
    shape = self.shape
    if len(shape) == 2:
        if axis == 0:
            mask = mask.ravel('F')
            self = self.ravel('F')
        else:
            mask = mask.ravel('C')
            self = self.ravel('C')
    out[mask] = intervaled_cumsum(self[mask], sizes)
    if len(shape) == 2:
        if axis == 0:
            return out.reshape(shape[1], shape[0]).T
        else:
            return out.reshape(shape)
    return out

cumsum_masked(a, m, axis=1)
I combined the answers and tried to optimize the speed, but it didn't work out. I'm posting it here in case other people need it.
There's intervaled_cumsum for 1D arrays. For this case, we simply need to get the masked elements, set up their island lengths, and feed those to that function.
Hence, one vectorized approach would be -
# https://stackoverflow.com/a/49179628/ by #Divakar
def intervaled_cumsum(ar, sizes):
# Make a copy to be used as output array
out = ar.copy()
# Get cumulative values of array
arc = ar.cumsum()
# Get cumsummed indices to be used to place differentiated values into
# input array's copy
idx = sizes.cumsum()
# Place differentiated values that, when cumulatively summed later on, would
# give us the desired intervaled cumsum
out[idx[0]] = ar[idx[0]] - arc[idx[0]-1]
out[idx[1:-1]] = ar[idx[1:-1]] - np.diff(arc[idx[:-1]-1])
return out.cumsum()
def intervaled_cumsum_masked_rowwise(a, mask):
    # Pad a False column on both sides so that True islands never wrap across rows
    z = np.zeros((mask.shape[0],1), dtype=bool)
    maskz = np.hstack((z,mask,z))
    out = np.zeros_like(a)
    # Island lengths, row by row, read off the False->True / True->False transitions
    sizes = np.diff(np.flatnonzero(maskz[:,1:] != maskz[:,:-1]))[::2]
    out[mask] = intervaled_cumsum(a[mask], sizes)
    return out
Sample run -
In [95]: a
Out[95]:
array([[ 0. , 1. , 2. , 3. , 4. , 5. , 6. , 7. ],
[ 6. , 7. , 8. , 9. , 10. , 4.2, 4.3, 11. ],
[12. , 13. , 14. , 15. , 16. , 4.2, 4.3, 17. ],
[18. , 19. , 20. , 21. , 22. , 4.2, 4.3, 23. ]])
In [96]: mask
Out[96]:
array([[False, True, True, True, False, True, True, False],
[False, False, False, True, True, True, True, False],
[False, False, True, True, False, False, False, False],
[False, True, True, False, False, False, True, True]])
In [97]: intervaled_cumsum_masked_rowwise(a, mask)
Out[97]:
array([[ 0. , 1. , 3. , 6. , 0. , 5. , 11. , 0. ],
[ 0. , 0. , 0. , 9. , 19. , 23.2, 27.5, 0. ],
[ 0. , 0. , 14. , 29. , 0. , 0. , 0. , 0. ],
[ 0. , 19. , 39. , 0. , 0. , 0. , 4.3, 27.3]])
Works just as well with negative numbers -
In [109]: a = -a
In [110]: a
Out[110]:
array([[ -0. , -1. , -2. , -3. , -4. , -5. , -6. , -7. ],
[ -6. , -7. , -8. , -9. , -10. , -4.2, -4.3, -11. ],
[-12. , -13. , -14. , -15. , -16. , -4.2, -4.3, -17. ],
[-18. , -19. , -20. , -21. , -22. , -4.2, -4.3, -23. ]])
In [111]: intervaled_cumsum_masked_rowwise(a, mask)
Out[111]:
array([[ 0. , -1. , -3. , -6. , 0. , -5. , -11. , 0. ],
[ 0. , 0. , 0. , -9. , -19. , -23.2, -27.5, 0. ],
[ 0. , 0. , -14. , -29. , 0. , 0. , 0. , 0. ],
[ 0. , -19. , -39. , 0. , 0. , 0. , -4.3, -27.3]])
Here is an approach that is quite a bit slower than #Divakar's and #filippo's but more robust. The problem with "global cumsummy" approaches is that they can suffer from loss of significance; see below:
import numpy as np
from scipy import linalg
def cumsums(data, mask, break_lines=True):
dr = data[mask]
if break_lines:
msk = mask.copy()
msk[:, 0] = False
mr = msk.ravel()[1:][mask.ravel()[:-1]][:dr.size-1]
else:
mr = mask.ravel()[1:][mask.ravel()[:-1]][:dr.size-1]
D = np.empty((2, dr.size))
D.T[...] = 1, 0
D[1, :-1] -= mr
out = np.zeros_like(data)
out[mask] = linalg.solve_banded((1, 0), D, dr)
return out
def f_staircase(a, m):
return np.cumsum(a, axis=1) - np.maximum.accumulate(np.cumsum(a, axis=1)*~m, axis=1)
# https://stackoverflow.com/a/49179628/ by #Divakar
def intervaled_cumsum(ar, sizes):
# Make a copy to be used as output array
out = ar.copy()
# Get cumulative values of array
arc = ar.cumsum()
# Get cumsummed indices to be used to place differentiated values into
# input array's copy
idx = sizes.cumsum()
# Place differentiated values that, when cumulatively summed later on, would
# give us the desired intervaled cumsum
out[idx[0]] = ar[idx[0]] - arc[idx[0]-1]
out[idx[1:-1]] = ar[idx[1:-1]] - np.diff(arc[idx[:-1]-1])
return out.cumsum()
def intervaled_cumsum_masked_rowwise(a, mask):
z = np.zeros((mask.shape[0],1), dtype=bool)
maskz = np.hstack((z,mask,z))
out = np.zeros_like(a)
sizes = np.diff(np.flatnonzero(maskz[:,1:] != maskz[:,:-1]))[::2]
out[mask] = intervaled_cumsum(a[mask], sizes)
return out
data = np.array([[ 0., 1., 2., 3., 4., 5. , 6. , 7.],
[ 6., 7., 8., 9., 10., 4.2, 4.3, 11.],
[ 12., 13., 14., 15., 16., 4.2, 4.3, 17.],
[ 18., 19., 20., 21., 22., 4.2, 4.3, 23.]])
mask = np.array([[False, True, True, True, False, True, True , False],
[False, False, False, True, True, True, True , False],
[False, False, True, True, False, False, False, False],
[False, True, True, False, False, False, True , True ]])
from timeit import timeit
print('fast?')
print('filippo', timeit(lambda: f_staircase(data, mask), number=1000))
print('pp ', timeit(lambda: cumsums(data, mask), number=1000))
print('divakar', timeit(lambda: intervaled_cumsum_masked_rowwise(data, mask), number=1000))
data = np.random.uniform(-10, 10, (5000, 5000))
mask = np.random.random((5000, 5000)) < 0.125
mask[:, 1:] |= mask[:, :-1]
mask[:, 2:] |= mask[:, :-2]
print()
print('fast on large data?')
print('filippo', timeit(lambda: f_staircase(data, mask), number=3))
print('pp ', timeit(lambda: cumsums(data, mask), number=3))
print('divakar', timeit(lambda: intervaled_cumsum_masked_rowwise(data, mask), number=3))
data = np.random.uniform(-10, 10, (10000, 10000))
mask = np.random.random((10000, 10000)) < 0.025
mask[:, 1:] |= mask[:, :-1]
mask[:, 2:] |= mask[:, :-2]
print()
print('fast on large sparse data?')
print('filippo', timeit(lambda: f_staircase(data, mask), number=3))
print('pp ', timeit(lambda: cumsums(data, mask), number=3))
print('divakar', timeit(lambda: intervaled_cumsum_masked_rowwise(data, mask), number=3))
data = np.exp(-np.linspace(-24, 24, 100))[None]
mask = (np.arange(100) % 4).astype(bool)[None]
print()
print('numerically sound?')
print('correct', data[0, -3:].sum())
print('filippo', f_staircase(data, mask)[0,-1])
print('pp ', cumsums(data, mask)[0,-1])
print('divakar', intervaled_cumsum_masked_rowwise(data, mask)[0,-1])
Output:
fast?
filippo 0.008435532916337252
pp 0.07329772273078561
divakar 0.0336935929954052
fast on large data?
filippo 1.6037923698313534
pp 3.982803522143513
divakar 1.706403402145952
fast on large sparse data?
filippo 6.11361704999581
pp 4.717669038102031
divakar 2.9474888620898128
numerically sound?
correct 1.9861262739950047e-10
filippo 0.0
pp 1.9861262739950047e-10
divakar 9.737630365237156e-06
We see that with the falling exponential example the cumsum-based approaches don't work: they recover each masked segment as a difference of large running sums, and subtracting two nearly equal numbers wipes out the tiny contributions at the tail. Obviously, this is an engineered example, but it showcases a real problem.
Here's an attempt to implement #hpaulj's suggestion:
>>> a = np.array([[ 0. , 1. , 2. , 3. , 4. , 5. , 6. , 7. ],
... [ 6. , 7. , 8. , 9. , 10. , 4.2, 4.3, 11. ],
... [12. , 13. , 14. , 15. , 16. , 4.2, 4.3, 17. ],
... [18. , 19. , 20. , 21. , 22. , 4.2, 4.3, 23. ]])
>>> m = np.array([[False, True, True, True, False, True, True, False],
... [False, False, False, True, True, True, True, False],
... [False, False, True, True, False, False, False, False],
... [False, True, True, False, False, False, True, True]])
>>> np.maximum.accumulate(np.cumsum(a, axis=1)*~m, axis=1)
array([[ 0. , 0. , 0. , 0. , 10. , 10. , 10. , 28. ],
[ 6. , 13. , 21. , 21. , 21. , 21. , 21. , 59.5],
[ 12. , 25. , 25. , 25. , 70. , 74.2, 78.5, 95.5],
[ 18. , 18. , 18. , 78. , 100. , 104.2, 104.2, 104.2]])
>>> np.cumsum(a, axis=1) - np.maximum.accumulate(np.cumsum(a, axis=1)*~m, axis=1)
array([[ 0. , 1. , 3. , 6. , 0. , 5. , 11. , 0. ],
[ 0. , 0. , 0. , 9. , 19. , 23.2, 27.5, 0. ],
[ 0. , 0. , 14. , 29. , 0. , 0. , 0. , 0. ],
[ 0. , 19. , 39. , 0. , 0. , 0. , 4.3, 27.3]])
See also Most efficient way to forward-fill NaN values in numpy array, which seems somewhat related. In particular, if your array is not >= 0 like in this toy example (the np.maximum.accumulate trick only recovers the cumsum at the last unmasked position when the running sum is non-decreasing), the accepted answer there should be helpful.
EDIT
For future reference, here's a version that removes the >= 0 assumption above: it records, for each position, the (flat) index of the most recent unmasked entry in its row and subtracts the row cumsum at that baseline. It should still be pretty fast, though I didn't benchmark it against the other methods.
In [38]: def masked_cumsum(a, m):
...: idx = np.maximum.accumulate(np.where(m, 0, np.arange(m.size).reshape(m.shape)), axis=1)
...: c = np.cumsum(a, axis=-1)
...: return c - c[np.unravel_index(idx, m.shape)]
...:
In [43]: masked_cumsum(-a, m)
Out[43]:
array([[ 0. , -1. , -3. , -6. , 0. , -5. , -11. , 0. ],
[ 0. , 0. , 0. , -9. , -19. , -23.2, -27.5, 0. ],
[ 0. , 0. , -14. , -29. , 0. , 0. , 0. , 0. ],
[ 0. , -19. , -39. , 0. , 0. , 0. , -4.3, -27.3]])
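As a quick sanity check (a sketch using the a and m arrays defined above and the expected output from the question), the generalized version also reproduces the desired result on the original, non-negated data:

import numpy as np

expected = np.array([[ 0. ,  1. ,  3. ,  6. ,  0. ,  5. , 11. ,  0. ],
                     [ 0. ,  0. ,  0. ,  9. , 19. , 23.2, 27.5,  0. ],
                     [ 0. ,  0. , 14. , 29. ,  0. ,  0. ,  0. ,  0. ],
                     [ 0. , 19. , 39. ,  0. ,  0. ,  0. ,  4.3, 27.3]])

assert np.allclose(masked_cumsum(a, m), expected)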
The equation is
$C'_{ijkl} = Q_{im} Q_{jn} C_{mnop} (Q^{-1})_{ok} (Q^{-1})_{pl}$
I was able to use
np.einsum('im,jn,mnop,ok,pl', Q, Q, C, Q_inv, Q_inv)
to do the job, and also expect
np.tensordot(np.tensordot(np.tensordot(Q, np.tensordot(Q, C, axes=[1,1]), axes=[1,0]), Q_inv, axes=[2,0]), Q_inv, axes=[3,0])
to work, but it doesn't.
Specifics:
C is a 4th-order elastic tensor:
array([[[[ 552.62389047, -0.28689554, -0.32194701],
[ -0.28689554, 118.89168597, -0.65559912],
[ -0.32194701, -0.65559912, 130.21758722]],
[[ -0.28689554, 166.02923119, -0.00000123],
[ 166.02923119, 0.49494431, -0.00000127],
[ -0.00000123, -0.00000127, -0.57156702]],
[[ -0.32194701, -0.00000123, 165.99413061],
[ -0.00000123, -0.64666809, -0.0000013 ],
[ 165.99413061, -0.0000013 , 0.42997465]]],
[[[ -0.28689554, 166.02923119, -0.00000123],
[ 166.02923119, 0.49494431, -0.00000127],
[ -0.00000123, -0.00000127, -0.57156702]],
[[ 118.89168597, 0.49494431, -0.64666809],
[ 0.49494431, 516.15898907, -0.33132485],
[ -0.64666809, -0.33132485, 140.09010389]],
[[ -0.65559912, -0.00000127, -0.0000013 ],
[ -0.00000127, -0.33132485, 165.98553869],
[ -0.0000013 , 165.98553869, 0.41913346]]],
[[[ -0.32194701, -0.00000123, 165.99413061],
[ -0.00000123, -0.64666809, -0.0000013 ],
[ 165.99413061, -0.0000013 , 0.42997465]],
[[ -0.65559912, -0.00000127, -0.0000013 ],
[ -0.00000127, -0.33132485, 165.98553869],
[ -0.0000013 , 165.98553869, 0.41913346]],
[[ 130.21758722, -0.57156702, 0.42997465],
[ -0.57156702, 140.09010389, 0.41913346],
[ 0.42997465, 0.41913346, 486.62412063]]]])
Q is a rotation matrix changing x and y coords.
array([[ 0, 1, 0],
[-1, 0, 0],
[ 0, 0, 1]])
Q_inv is
array([[-0., -1., -0.],
[ 1., 0., 0.],
[ 0., 0., 1.]])
np.einsum leads to
array([[[[ 516.15898907, -0.49494431, -0.33132485],
[ -0.49494431, 118.89168597, 0.64666809],
[ -0.33132485, 0.64666809, 140.09010389]],
[[ -0.49494431, 166.02923119, 0.00000127],
[ 166.02923119, 0.28689554, -0.00000123],
[ 0.00000127, -0.00000123, 0.57156702]],
[[ -0.33132485, 0.00000127, 165.98553869],
[ 0.00000127, -0.65559912, 0.0000013 ],
[ 165.98553869, 0.0000013 , 0.41913346]]],
[[[ -0.49494431, 166.02923119, 0.00000127],
[ 166.02923119, 0.28689554, -0.00000123],
[ 0.00000127, -0.00000123, 0.57156702]],
[[ 118.89168597, 0.28689554, -0.65559912],
[ 0.28689554, 552.62389047, 0.32194701],
[ -0.65559912, 0.32194701, 130.21758722]],
[[ 0.64666809, -0.00000123, 0.0000013 ],
[ -0.00000123, 0.32194701, 165.99413061],
[ 0.0000013 , 165.99413061, -0.42997465]]],
[[[ -0.33132485, 0.00000127, 165.98553869],
[ 0.00000127, -0.65559912, 0.0000013 ],
[ 165.98553869, 0.0000013 , 0.41913346]],
[[ 0.64666809, -0.00000123, 0.0000013 ],
[ -0.00000123, 0.32194701, 165.99413061],
[ 0.0000013 , 165.99413061, -0.42997465]],
[[ 140.09010389, 0.57156702, 0.41913346],
[ 0.57156702, 130.21758722, -0.42997465],
[ 0.41913346, -0.42997465, 486.62412063]]]])
which I believe is correct, while the chain of four np.tensordot calls leads to
array([[[[ 552.62389047, -0.28689554, 0.32194701],
[ -0.28689554, 118.89168597, 0.65559912],
[ -0.32194701, -0.65559912, -130.21758722]],
[[ -0.28689554, 166.02923119, 0.00000123],
[ 166.02923119, 0.49494431, 0.00000127],
[ -0.00000123, -0.00000127, 0.57156702]],
[[ -0.32194701, -0.00000123, -165.99413061],
[ -0.00000123, -0.64666809, 0.0000013 ],
[ 165.99413061, -0.0000013 , -0.42997465]]],
[[[ -0.28689554, 166.02923119, 0.00000123],
[ 166.02923119, 0.49494431, 0.00000127],
[ -0.00000123, -0.00000127, 0.57156702]],
[[ 118.89168597, 0.49494431, 0.64666809],
[ 0.49494431, 516.15898907, 0.33132485],
[ -0.64666809, -0.33132485, -140.09010389]],
[[ -0.65559912, -0.00000127, 0.0000013 ],
[ -0.00000127, -0.33132485, -165.98553869],
[ -0.0000013 , 165.98553869, -0.41913346]]],
[[[ 0.32194701, 0.00000123, 165.99413061],
[ 0.00000123, 0.64666809, -0.0000013 ],
[-165.99413061, 0.0000013 , 0.42997465]],
[[ 0.65559912, 0.00000127, -0.0000013 ],
[ 0.00000127, 0.33132485, 165.98553869],
[ 0.0000013 , -165.98553869, 0.41913346]],
[[-130.21758722, 0.57156702, 0.42997465],
[ 0.57156702, -140.09010389, 0.41913346],
[ -0.42997465, -0.41913346, 486.62412063]]]])
Notice the large entries that have come out negative.
Approach #1
One way would be to use np.tensordot to get the same result as with np.einsum though not in a single step and with some help from the trusty broadcasting -
# Get broadcasted elementwise multiplication between two versions of Q.
# This corresponds to "np.einsum('im,jn->ijmn', Q, Q)", i.e. a broadcasted
# version of the elementwise multiplication between the two Q's.
Q_ext = Q[:,None,:,None]*Q[:,None,:]
# Similarly for Q_inv: "np.einsum('ok,pl->opkl', Q_inv, Q_inv)" as a broadcasted
# version of the elementwise multiplication between the two Q_inv's.
Q_inv_ext = Q_inv[:,None,:,None]*Q_inv[:,None,:]
# Perform "np.einsum('im,jn,mnop,ok,pl', Q, Q, C)" with "np.tensordot".
# Notice that we are using the last two axes from 'Q_ext', so "axes=[2,3]"
# and first two from 'C', so "axes=[0,1]" for it.
# These axes would be reduced by the dot-product, leaving us with 'ijop'.
parte1 = np.tensordot(Q_ext,C,axes=([2,3],[0,1]))
# Do it one more time to perform "np.einsum('ijop,ok,pl', parte1,Q_inv,Q_inv)"
# to reduce dimensions represented by 'o,p', leaving us with 'ijkl'.
# To confirm, compare the following against original einsum approach :
# "np.einsum('im,jn,mnop,ok,pl->ijkl', Q, Q, C, Q_inv, Q_inv)"
out = np.tensordot(parte1,Q_inv_ext,axes=([2,3],[0,1]))
Approach #2
If you wish to avoid broadcasting in favour of using two more instances of np.tensordot, you could do -
# Perform "np.einsum('jn,mnop', Q, C). Notice how, Q is represented by 'jn'
# and C by 'mnop'. We need to reduce the 'm' dimension, i.e. reduce 'axes=1'
# from Q and `axes=1` from C corresponding to `n' in each of the inputs.
# Thus, 'jn' + 'mnop' => 'jmop' after 'n' is reduced and order is maintained.
Q_C1 = np.tensordot(Q,C,axes=([1],[1]))
# Perform "np.einsum('im,jn,mnop', Q, Q, C). We need to use Q and Q_C1.
# Q is 'im' and Q_C1 is 'jmop'. Thus, again we need to reduce 'axes=1'
# from Q and `axes=1` from Q_C1 corresponding to `m' in each of the inputs.
# Thus, 'im' + 'jmop' => 'ijop' after 'm' is reduced and order is maintained.
parte1 = np.tensordot(Q,Q_C1,axes=([1],[1]))
# Use the same philosophy to get the rest of the einsum equivalent,
# but use parte1 and go right and use Q_inv
out = np.tensordot(np.tensordot(parte1,Q_inv,axes=([2],[0])),Q_inv,axes=([2],[0]))
The trick with np.tensordot is to keep track of the dimensions that are reduced by the axes parameter and of the order in which the remaining dimensions of the inputs appear in the result.
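As a final sanity check (a sketch, assuming the Q, C and Q_inv arrays from the question are already defined), both tensordot chains can be compared against the reference einsum with np.allclose:

import numpy as np

ref = np.einsum('im,jn,mnop,ok,pl->ijkl', Q, Q, C, Q_inv, Q_inv)

# Approach #1: broadcasted outer products, then two tensordots
Q_ext = Q[:,None,:,None]*Q[:,None,:]                      # 'ijmn'
Q_inv_ext = Q_inv[:,None,:,None]*Q_inv[:,None,:]          # 'opkl'
out1 = np.tensordot(np.tensordot(Q_ext, C, axes=([2,3],[0,1])),
                    Q_inv_ext, axes=([2,3],[0,1]))        # 'ijkl'

# Approach #2: contract one subscript at a time
Q_C1 = np.tensordot(Q, C, axes=([1],[1]))                 # 'jmop'
parte1 = np.tensordot(Q, Q_C1, axes=([1],[1]))            # 'ijop'
out2 = np.tensordot(np.tensordot(parte1, Q_inv, axes=([2],[0])),
                    Q_inv, axes=([2],[0]))                # 'ijkl'

print(np.allclose(out1, ref), np.allclose(out2, ref))     # expected: True True

This also shows where the chain in the question goes wrong: after the innermost np.tensordot(Q, C, axes=[1,1]) the intermediate is ordered 'jmop', so the next call needs axes=([1],[1]) to contract 'm'; with axes=[1,0] it contracts against 'j' instead, which is what scrambles the signs in the output.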