@@ -199,6 +199,30 @@ void bufferize(void *cptr, DTypeId dtype, const int64_t *sizes,
199199 });
200200}
201201
202+ // / copy contiguous block of data into a possibly strided tensor
203+ void unpack (void *in, DTypeId dtype, const int64_t *sizes,
204+ const int64_t *strides, const int64_t *tStarts,
205+ const int64_t *tSizes, uint64_t nd, uint64_t N, void *out) {
206+ dispatch (dtype, out, [sizes, strides, tStarts, tSizes, nd, N, in](auto *ptr) {
207+ auto buff = static_cast <decltype (ptr)>(in);
208+
209+ for (auto i = 0 ; i < N; ++i) {
210+ auto szs = &tSizes[i * nd];
211+ if (szs[0 ] > 0 ) {
212+ auto sts = &tStarts[i * nd];
213+ uint64_t off = 0 ;
214+ for (int64_t r = 0 ; r < nd; ++r) {
215+ off += sts[r] * strides[r];
216+ }
217+ forall (0 , &ptr[off], szs, strides, nd, [&buff](auto *out) {
218+ *out = *buff;
219+ ++buff;
220+ });
221+ }
222+ }
223+ });
224+ }
225+
202226template <typename T>
203227void copy_ (uint64_t d, uint64_t &pos, T *cptr, const int64_t *sizes,
204228 const int64_t *strides, const uint64_t *chunks, uint64_t nd,
@@ -430,7 +454,9 @@ TYPED_RESHAPE(i1, bool);
430454void _idtr_update_halo (DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
431455 int64_t *ownedShape, int64_t *ownedStride,
432456 int64_t *bbOff, int64_t *bbShape, void *ownedData,
433- void *leftHaloData, void *rightHaloData,
457+ int64_t *leftHaloShape, int64_t *leftHaloStride,
458+ void *leftHaloData, int64_t *rightHaloShape,
459+ int64_t *rightHaloStride, void *rightHaloData,
434460 Transceiver *tc) {
435461
436462#ifdef NO_TRANSCEIVER
@@ -478,7 +504,6 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
478504 bool bufferizeSend = (!is_contiguous (ownedShape, ownedStride, ndims) ||
479505 bbTotCols != ownedTotCols);
480506
481- // assert(!bufferizeSend);
482507 std::vector<int64_t > lBufferStart (nworkers * ndims, 0 );
483508 std::vector<int64_t > lBufferSize (nworkers * ndims, 0 );
484509 std::vector<int64_t > rBufferStart (nworkers * ndims, 0 );
@@ -552,24 +577,70 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
552577 rRecvOff[i] = rRecvOff[i - 1 ] + rRecvSize[i - 1 ];
553578 }
554579
555- // communicate left/right halos
580+ // receive buffering
581+ void *lRecvData, *rRecvData;
582+ void *sendData;
583+ bool bufferizeLRecv = !is_contiguous (leftHaloShape, leftHaloStride, ndims);
584+ bool bufferizeRRecv = !is_contiguous (rightHaloShape, rightHaloStride, ndims);
585+ std::vector<int64_t > recvBufferStart (nworkers * ndims, 0 );
586+ std::vector<int64_t > lRecvBufferSize (nworkers * ndims, 0 );
587+ std::vector<int64_t > rRecvBufferSize (nworkers * ndims, 0 );
588+
589+ // deduce receive shape for unpack
590+ int64_t lTotalRecvSize = 0 , rTotalRecvSize = 0 ;
591+ for (auto i = 0 ; i < nworkers; ++i) {
592+ if (bufferizeLRecv && lRecvSize[i] != 0 ) {
593+ lTotalRecvSize += lRecvSize[i];
594+ lRecvBufferSize[i * ndims] = lRecvSize[i] / bbTotCols; // nrows
595+ for (auto j = 1 ; j < ndims; ++j) {
596+ lRecvBufferSize[i * ndims + j] = bbShape[j]; // leftHaloShape[j]
597+ }
598+ }
599+ if (bufferizeRRecv && rRecvSize[i] != 0 ) {
600+ rTotalRecvSize += rRecvSize[i];
601+ rRecvBufferSize[i * ndims] = rRecvSize[i] / bbTotCols; // nrows
602+ for (auto j = 1 ; j < ndims; ++j) {
603+ rRecvBufferSize[i * ndims + j] = bbShape[j]; // rightHaloShape[j]
604+ }
605+ }
606+ }
607+
608+ Buffer recvBuff;
609+ Buffer sendBuff;
610+ if (bufferizeLRecv || bufferizeLRecv) {
611+ recvBuff.resize (std::max (lTotalRecvSize, rTotalRecvSize) *
612+ sizeof_dtype (ddpttype));
613+ }
556614 if (bufferizeSend) {
557- Buffer sendBuff;
558615 sendBuff.resize (std::max (lTotalSendSize, rTotalSendSize) *
559616 sizeof_dtype (ddpttype));
617+ }
618+ lRecvData = bufferizeLRecv ? recvBuff.data () : leftHaloData;
619+ rRecvData = bufferizeRRecv ? recvBuff.data () : rightHaloData;
620+ sendData = bufferizeSend ? sendBuff.data () : ownedData;
621+
622+ // communicate left/right halos
623+ if (bufferizeSend) {
560624 bufferize (ownedData, ddpttype, ownedShape, ownedStride, lBufferStart.data (),
561625 lBufferSize.data (), ndims, nworkers, sendBuff.data ());
562- tc->alltoall (sendBuff.data (), lSendSize.data (), lSendOff.data (), ddpttype,
563- leftHaloData, lRecvSize.data (), lRecvOff.data ());
626+ }
627+ tc->alltoall (sendData, lSendSize.data (), lSendOff.data (), ddpttype, lRecvData,
628+ lRecvSize.data (), lRecvOff.data ());
629+ if (bufferizeLRecv) {
630+ unpack (lRecvData, ddpttype, leftHaloShape, leftHaloStride,
631+ recvBufferStart.data (), lRecvBufferSize.data (), ndims, nworkers,
632+ leftHaloData);
633+ }
634+ if (bufferizeSend) {
564635 bufferize (ownedData, ddpttype, ownedShape, ownedStride, rBufferStart.data (),
565636 rBufferSize.data (), ndims, nworkers, sendBuff.data ());
566- tc-> alltoall (sendBuff. data (), rSendSize. data (), rSendOff. data (), ddpttype,
567- rightHaloData, rRecvSize .data (), rRecvOff .data ());
568- } else {
569- tc-> alltoall (ownedData, lSendSize. data (), lSendOff. data (), ddpttype,
570- leftHaloData, lRecvSize. data (), lRecvOff. data ());
571- tc-> alltoall (ownedData, rSendSize .data (), rSendOff .data (), ddpttype ,
572- rightHaloData, rRecvSize. data (), rRecvOff. data () );
637+ }
638+ tc-> alltoall (sendData, rSendSize .data (), rSendOff .data (), ddpttype, rRecvData,
639+ rRecvSize. data (), rRecvOff. data ());
640+ if (bufferizeRRecv) {
641+ unpack (rRecvData, ddpttype, rightHaloShape, rightHaloStride,
642+ recvBufferStart .data (), rRecvBufferSize .data (), ndims, nworkers ,
643+ rightHaloData);
573644 }
574645}
575646
@@ -593,13 +664,11 @@ void _idtr_update_halo(Transceiver *tc, int64_t gShapeRank, void *gShapeDescr,
593664 UnrankedMemRefType<T> leftHalo (lHaloRank, lHaloDescr);
594665 UnrankedMemRefType<T> rightHalo (rHaloRank, rHaloDescr);
595666
596- assert (leftHalo.contiguous_layout ());
597- assert (rightHalo.contiguous_layout ());
598-
599667 _idtr_update_halo (ddpttype, ownedData.rank (), ownedOff.data (),
600668 ownedData.sizes (), ownedData.strides (), bbOff.data (),
601- bbShape.data (), ownedData.data (), leftHalo.data (),
602- rightHalo.data (), tc);
669+ bbShape.data (), ownedData.data (), leftHalo.sizes (),
670+ leftHalo.strides (), leftHalo.data (), rightHalo.sizes (),
671+ rightHalo.strides (), rightHalo.data (), tc);
603672}
604673
605674extern " C" {
0 commit comments