@@ -30,7 +30,33 @@ template <typename T> T *mr_to_ptr(void *ptr, intptr_t offset) {
3030 return reinterpret_cast <T *>(ptr) + offset; // &mr.aligned[mr.offset]
3131}
3232
33+ // abstract handle providing an abstract wait method
34+ struct WaitHandleBase {
35+ virtual ~WaitHandleBase (){};
36+ virtual void wait () = 0;
37+ };
38+
39+ // concrete handle to be instantiated with a lambda or alike
40+ // the lambda will be executed within wait()
41+ template <typename T> class WaitHandle : public WaitHandleBase {
42+ T _fini;
43+
44+ public:
45+ WaitHandle (T fini) : _fini(fini) {}
46+ virtual void wait () override { _fini (); }
47+ };
48+
49+ template <typename T> WaitHandle<T> *mkWaitHandle (T fini) {
50+ return new WaitHandle<T>(fini);
51+ };
52+
3353extern " C" {
54+ void _idtr_wait (WaitHandleBase *handle) {
55+ if (handle) {
56+ handle->wait ();
57+ delete handle;
58+ }
59+ }
3460
3561#define NO_TRANSCEIVER
3662#ifdef NO_TRANSCEIVER
@@ -401,8 +427,9 @@ void _idtr_reshape(DTypeId ddpttype, int64_t lRank, int64_t *gShapePtr,
401427 Buffer outbuff (totSSz * sizeof_dtype (ddpttype), 2 ); // FIXME debug value
402428 bufferizeN (lDataPtr, ddpttype, lShapePtr, lStridesPtr, lsOffs.data (),
403429 lsEnds.data (), lRank, N, outbuff.data ());
404- tc->alltoall (outbuff.data (), sszs.data (), soffs.data (), ddpttype, oDataPtr,
405- rszs.data (), roffs.data ());
430+ auto hdl = tc->alltoall (outbuff.data (), sszs.data (), soffs.data (), ddpttype,
431+ oDataPtr, rszs.data (), roffs.data ());
432+ tc->wait (hdl);
406433}
407434
408435// / @brief reshape tensor
@@ -451,21 +478,22 @@ TYPED_RESHAPE(i1, bool);
451478// / @brief Update data in halo parts
452479// / We assume tensor is partitioned along the first dimension only
453480// / (row partitioning) and partitions are ordered by ranks
454- void _idtr_update_halo (DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
455- int64_t *ownedShape, int64_t *ownedStride,
456- int64_t *bbOff, int64_t *bbShape, void *ownedData,
457- int64_t *leftHaloShape, int64_t *leftHaloStride,
458- void *leftHaloData, int64_t *rightHaloShape,
459- int64_t *rightHaloStride, void *rightHaloData,
460- Transceiver *tc) {
481+ // / @return (MPI) handles
482+ void *_idtr_update_halo (DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
483+ int64_t *ownedShape, int64_t *ownedStride,
484+ int64_t *bbOff, int64_t *bbShape, void *ownedData,
485+ int64_t *leftHaloShape, int64_t *leftHaloStride,
486+ void *leftHaloData, int64_t *rightHaloShape,
487+ int64_t *rightHaloStride, void *rightHaloData,
488+ Transceiver *tc) {
461489
462490#ifdef NO_TRANSCEIVER
463491 initMPIRuntime ();
464492 tc = getTransceiver ();
465493#endif
466494 auto nworkers = tc->nranks ();
467495 if (nworkers <= 1 || getenv (" DDPT_SKIP_COMM" ))
468- return ;
496+ return nullptr ;
469497 auto myWorkerIndex = tc->rank ();
470498
471499 // Gather table with bounding box offsets and shapes for all workers
@@ -582,7 +610,6 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
582610 void *sendData;
583611 bool bufferizeLRecv = !is_contiguous (leftHaloShape, leftHaloStride, ndims);
584612 bool bufferizeRRecv = !is_contiguous (rightHaloShape, rightHaloStride, ndims);
585- std::vector<int64_t > recvBufferStart (nworkers * ndims, 0 );
586613 std::vector<int64_t > lRecvBufferSize (nworkers * ndims, 0 );
587614 std::vector<int64_t > rRecvBufferSize (nworkers * ndims, 0 );
588615
@@ -624,35 +651,50 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
624651 bufferize (ownedData, ddpttype, ownedShape, ownedStride, lBufferStart.data (),
625652 lBufferSize.data (), ndims, nworkers, sendBuff.data ());
626653 }
627- tc->alltoall (sendData, lSendSize.data (), lSendOff.data (), ddpttype, lRecvData,
628- lRecvSize.data (), lRecvOff.data ());
629- if (bufferizeLRecv) {
630- unpack (lRecvData, ddpttype, leftHaloShape, leftHaloStride,
631- recvBufferStart.data (), lRecvBufferSize.data (), ndims, nworkers,
632- leftHaloData);
633- }
654+ auto lwh = tc->alltoall (sendData, lSendSize.data (), lSendOff.data (), ddpttype,
655+ lRecvData, lRecvSize.data (), lRecvOff.data ());
634656 if (bufferizeSend) {
635657 bufferize (ownedData, ddpttype, ownedShape, ownedStride, rBufferStart.data (),
636658 rBufferSize.data (), ndims, nworkers, sendBuff.data ());
637659 }
638- tc->alltoall (sendData, rSendSize.data (), rSendOff.data (), ddpttype, rRecvData,
639- rRecvSize.data (), rRecvOff.data ());
640- if (bufferizeRRecv) {
641- unpack (rRecvData, ddpttype, rightHaloShape, rightHaloStride,
642- recvBufferStart.data (), rRecvBufferSize.data (), ndims, nworkers,
643- rightHaloData);
660+ auto rwh = tc->alltoall (sendData, rSendSize.data (), rSendOff.data (), ddpttype,
661+ rRecvData, rRecvSize.data (), rRecvOff.data ());
662+
663+ auto wait = [=, _lRecvBufferSize = std::move (lRecvBufferSize),
664+ _rRecvBufferSize = std::move (rRecvBufferSize)]() {
665+ tc->wait (lwh);
666+ std::vector<int64_t > recvBufferStart (nworkers * ndims, 0 );
667+ if (bufferizeLRecv) {
668+ unpack (lRecvData, ddpttype, leftHaloShape, leftHaloStride,
669+ recvBufferStart.data (), _lRecvBufferSize.data (), ndims, nworkers,
670+ leftHaloData);
671+ }
672+ tc->wait (rwh);
673+ if (bufferizeRRecv) {
674+ unpack (rRecvData, ddpttype, rightHaloShape, rightHaloStride,
675+ recvBufferStart.data (), _rRecvBufferSize.data (), ndims, nworkers,
676+ rightHaloData);
677+ }
678+ };
679+ assert (lRecvBufferSize.empty () && rRecvBufferSize.empty ());
680+
681+ // FIXME (in imex) buffer-dealloc pass deallocs halo strides and sizes
682+ if (bufferizeLRecv || bufferizeRRecv || getenv (" DDPT_NO_ASYNC" )) {
683+ wait ();
684+ return nullptr ;
644685 }
686+ return mkWaitHandle (wait);
645687}
646688
647689// / @brief templated wrapper for typed function versions calling
648690// / _idtr_update_halo
649691template <typename T>
650- void _idtr_update_halo (Transceiver *tc, int64_t gShapeRank , void *gShapeDescr ,
651- int64_t oOffRank, void *oOffDescr, int64_t oDataRank,
652- void *oDataDescr, int64_t bbOffRank, void *bbOffDescr,
653- int64_t bbShapeRank, void *bbShapeDescr,
654- int64_t lHaloRank, void *lHaloDescr, int64_t rHaloRank,
655- void *rHaloDescr) {
692+ void * _idtr_update_halo (Transceiver *tc, int64_t gShapeRank , void *gShapeDescr ,
693+ int64_t oOffRank, void *oOffDescr, int64_t oDataRank,
694+ void *oDataDescr, int64_t bbOffRank, void *bbOffDescr,
695+ int64_t bbShapeRank, void *bbShapeDescr,
696+ int64_t lHaloRank, void *lHaloDescr, int64_t rHaloRank,
697+ void *rHaloDescr) {
656698
657699 auto ddpttype = DTYPE<T>::value;
658700
@@ -664,25 +706,25 @@ void _idtr_update_halo(Transceiver *tc, int64_t gShapeRank, void *gShapeDescr,
664706 UnrankedMemRefType<T> leftHalo (lHaloRank, lHaloDescr);
665707 UnrankedMemRefType<T> rightHalo (rHaloRank, rHaloDescr);
666708
667- _idtr_update_halo (ddpttype, ownedData. rank (), ownedOff. data (),
668- ownedData.sizes (), ownedData. strides (), bbOff. data (),
669- bbShape .data (), ownedData .data (), leftHalo. sizes (),
670- leftHalo.strides (), leftHalo.data (), rightHalo.sizes (),
671- rightHalo.strides (), rightHalo.data (), tc);
709+ return _idtr_update_halo (
710+ ddpttype, ownedData.rank (), ownedOff. data (), ownedData. sizes (),
711+ ownedData. strides (), bbOff .data (), bbShape .data (), ownedData. data (),
712+ leftHalo. sizes (), leftHalo.strides (), leftHalo.data (), rightHalo.sizes (),
713+ rightHalo.strides (), rightHalo.data (), tc);
672714}
673715
674716extern " C" {
675717#define TYPED_UPDATE_HALO (_sfx, _typ ) \
676- void _idtr_update_halo_##_sfx( \
718+ void * _idtr_update_halo_##_sfx( \
677719 Transceiver *tc, int64_t gShapeRank , void *gShapeDescr , \
678720 int64_t oOffRank, void *oOffDescr, int64_t oDataRank, void *oDataDescr, \
679721 int64_t bbOffRank, void *bbOffDescr, int64_t bbShapeRank, \
680722 void *bbShapeDescr, int64_t lHaloRank, void *lHaloDescr, \
681723 int64_t rHaloRank, void *rHaloDescr) { \
682- _idtr_update_halo<_typ>(tc, gShapeRank , gShapeDescr , oOffRank, oOffDescr, \
683- oDataRank, oDataDescr, bbOffRank, bbOffDescr, \
684- bbShapeRank, bbShapeDescr, lHaloRank, lHaloDescr, \
685- rHaloRank, rHaloDescr); \
724+ return _idtr_update_halo<_typ>( \
725+ tc, gShapeRank , gShapeDescr , oOffRank, oOffDescr, oDataRank, \
726+ oDataDescr, bbOffRank, bbOffDescr, bbShapeRank, bbShapeDescr, \
727+ lHaloRank, lHaloDescr, rHaloRank, rHaloDescr); \
686728 }
687729
688730TYPED_UPDATE_HALO (f64 , double );
0 commit comments