@@ -51,14 +51,30 @@ template <typename T> WaitHandle<T> *mkWaitHandle(T fini) {
5151 return new WaitHandle<T>(fini);
5252};
5353
54- extern " C " {
55- void _idtr_wait (WaitHandleBase *handle ) {
// Diff note: the two removed lines above are the old `extern "C"` opener and
// the old one-argument signature; the added lines below replace that
// signature with one taking four additional halo arguments.
//
// _idtr_wait: completes and releases a wait handle.
//   handle      - may be null; when non-null it is waited on and then deleted,
//                 so the caller must not use it afterwards.
//   lHaloRank, lHaloDescr, rHaloRank, rHaloDescr
//               - not referenced anywhere in the body shown here.
//                 NOTE(review): presumably accepted so the caller's halo
//                 descriptors stay alive until the wait has run - confirm
//                 against the call site.
54+ void _idtr_wait (WaitHandleBase *handle, int64_t lHaloRank, void *lHaloDescr,
55+ int64_t rHaloRank, void *rHaloDescr ) {
// A null handle makes the whole call a no-op.
5656 if (handle) {
5757 handle->wait ();
// The handle is consumed here: it is freed right after the wait returns.
5858 delete handle;
5959 }
6060}
6161
62+ extern " C" {
// TYPED_WAIT(_sfx) defines a function named `_idtr_wait_<_sfx>` whose
// parameter list matches `_idtr_wait` and whose body forwards all five
// arguments to `_idtr_wait` unchanged. The generated wrappers therefore do not
// differ in behavior; only their names differ.
// These definitions follow the `extern "C" {` opener on the preceding line.
63+ #define TYPED_WAIT (_sfx ) \
64+ void _idtr_wait_##_sfx(WaitHandleBase *handle, int64_t lHaloRank, \
65+ void *lHaloDescr, int64_t rHaloRank, \
66+ void *rHaloDescr) { \
67+ return _idtr_wait (handle, lHaloRank, lHaloDescr, rHaloRank, rHaloDescr); \
68+ }
69+
// One wrapper per suffix: f64, f32, i64, i32, i16, i8, i1.
// NOTE(review): the suffixes look like element-type tags - confirm against the
// code that emits calls to these symbols.
70+ TYPED_WAIT (f64 );
71+ TYPED_WAIT (f32 );
72+ TYPED_WAIT (i64 );
73+ TYPED_WAIT (i32 );
74+ TYPED_WAIT (i16 );
75+ TYPED_WAIT (i8 );
76+ TYPED_WAIT (i1);
77+
6278#define NO_TRANSCEIVER
6379#ifdef NO_TRANSCEIVER
6480static void initMPIRuntime () {
@@ -486,6 +502,8 @@ struct UHCache {
486502 std::vector<int > _lSendSize, _rSendSize, _lSendOff, _rSendOff;
487503 // receive maps
488504 std::vector<int > _lRecvSize, _rRecvSize, _lRecvOff, _rRecvOff;
505+ // buffers
506+ Buffer _recvBuff, _sendLBuff, _sendRBuff;
489507 bool _bufferizeSend, _bufferizeLRecv, _bufferizeRRecv;
490508 // start and sizes for chunks from remotes if copies are needed
491509 int64_t _lTotalRecvSize, _rTotalRecvSize, _lTotalSendSize, _rTotalSendSize;
@@ -502,6 +520,7 @@ struct UHCache {
502520 std::vector<int > &&rSendSize, std::vector<int > &&lSendOff,
503521 std::vector<int > &&rSendOff, std::vector<int > &&lRecvSize,
504522 std::vector<int > &&rRecvSize, std::vector<int > &&lRecvOff,
523+ Buffer &&recvBuff, Buffer &&sendLBuff, Buffer &&sendRBuff,
505524 std::vector<int > &&rRecvOff, bool bufferizeSend, bool bufferizeLRecv,
506525 bool bufferizeRRecv, int64_t lTotalRecvSize, int64_t rTotalRecvSize,
507526 int64_t lTotalSendSize, int64_t rTotalSendSize)
@@ -515,10 +534,11 @@ struct UHCache {
515534 _lSendOff(std::move(lSendOff)), _rSendOff(std::move(rSendOff)),
516535 _lRecvSize(std::move(lRecvSize)), _rRecvSize(std::move(rRecvSize)),
517536 _lRecvOff(std::move(lRecvOff)), _rRecvOff(std::move(rRecvOff)),
518- _bufferizeSend(bufferizeSend), _bufferizeLRecv(bufferizeLRecv),
519- _bufferizeRRecv(bufferizeRRecv), _lTotalRecvSize(lTotalRecvSize),
520- _rTotalRecvSize(rTotalRecvSize), _lTotalSendSize(lTotalSendSize),
521- _rTotalSendSize(rTotalSendSize) {}
537+ _recvBuff(std::move(recvBuff)), _sendLBuff(std::move(sendLBuff)),
538+ _sendRBuff(std::move(sendRBuff)), _bufferizeSend(bufferizeSend),
539+ _bufferizeLRecv(bufferizeLRecv), _bufferizeRRecv(bufferizeRRecv),
540+ _lTotalRecvSize(lTotalRecvSize), _rTotalRecvSize(rTotalRecvSize),
541+ _lTotalSendSize(lTotalSendSize), _rTotalSendSize(rTotalSendSize) {}
522542 UHCache &operator =(const UHCache &) = delete ;
523543 UHCache &operator =(UHCache &&) = default ;
524544};
@@ -712,35 +732,40 @@ void *_idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
712732 }
713733 cache = &(cIt->second );
714734
715- Buffer recvBuff (0 ), sendBuff (0 );
716735 if (cache->_bufferizeLRecv || cache->_bufferizeRRecv ) {
717- recvBuff.resize (std::max (cache->_lTotalRecvSize , cache->_rTotalRecvSize ) *
718- sizeof_dtype (ddpttype));
736+ cache->_recvBuff .resize (
737+ std::max (cache->_lTotalRecvSize , cache->_rTotalRecvSize ) *
738+ sizeof_dtype (ddpttype));
719739 }
720740 if (cache->_bufferizeSend ) {
721- sendBuff .resize (std::max ( cache->_lTotalSendSize , cache-> _rTotalSendSize ) *
722- sizeof_dtype (ddpttype));
741+ cache-> _sendLBuff .resize (cache->_lTotalSendSize * sizeof_dtype (ddpttype));
742+ cache-> _sendRBuff . resize (cache-> _rTotalSendSize * sizeof_dtype (ddpttype));
723743 }
724744
725- void *lRecvData = cache->_bufferizeLRecv ? recvBuff.data () : leftHaloData;
726- void *rRecvData = cache->_bufferizeRRecv ? recvBuff.data () : rightHaloData;
727- void *sendData = cache->_bufferizeSend ? sendBuff.data () : ownedData;
745+ void *lRecvData =
746+ cache->_bufferizeLRecv ? cache->_recvBuff .data () : leftHaloData;
747+ void *rRecvData =
748+ cache->_bufferizeRRecv ? cache->_recvBuff .data () : rightHaloData;
749+ void *lSendData =
750+ cache->_bufferizeSend ? cache->_sendLBuff .data () : ownedData;
751+ void *rSendData =
752+ cache->_bufferizeSend ? cache->_sendRBuff .data () : ownedData;
728753
729754 // communicate left/right halos
730755 if (cache->_bufferizeSend ) {
731756 bufferize (ownedData, ddpttype, ownedShape, ownedStride,
732757 cache->_lBufferStart .data (), cache->_lBufferSize .data (), ndims,
733- nworkers, sendBuff .data ());
758+ nworkers, cache-> _sendLBuff .data ());
734759 }
735- auto lwh = tc->alltoall (sendData , cache->_lSendSize .data (),
760+ auto lwh = tc->alltoall (lSendData , cache->_lSendSize .data (),
736761 cache->_lSendOff .data (), ddpttype, lRecvData,
737762 cache->_lRecvSize .data (), cache->_lRecvOff .data ());
738763 if (cache->_bufferizeSend ) {
739764 bufferize (ownedData, ddpttype, ownedShape, ownedStride,
740765 cache->_rBufferStart .data (), cache->_rBufferSize .data (), ndims,
741- nworkers, sendBuff .data ());
766+ nworkers, cache-> _sendRBuff .data ());
742767 }
743- auto rwh = tc->alltoall (sendData , cache->_rSendSize .data (),
768+ auto rwh = tc->alltoall (rSendData , cache->_rSendSize .data (),
744769 cache->_rSendOff .data (), ddpttype, rRecvData,
745770 cache->_rRecvSize .data (), cache->_rRecvOff .data ());
746771
@@ -760,7 +785,6 @@ void *_idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
760785 }
761786 };
762787
763- // FIXME (in imex) buffer-dealloc pass deallocs halo strides and sizes
764788 if (cache->_bufferizeLRecv || cache->_bufferizeRRecv ||
765789 getenv (" DDPT_NO_ASYNC" )) {
766790 wait ();
0 commit comments