1313#include < cassert>
1414#include < iostream>
1515#include < memory>
16+ #include < unordered_map>
1617
1718constexpr id_t UNKNOWN_GUID = -1 ;
1819
@@ -475,26 +476,64 @@ TYPED_RESHAPE(i1, bool);
475476
476477} // extern "C"
477478
478- // / @brief Update data in halo parts
479- // / We assume tensor is partitioned along the first dimension only
480- // / (row partitioning) and partitions are ordered by ranks
481- // / @return (MPI) handles
482- void *_idtr_update_halo (DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
483- int64_t *ownedShape, int64_t *ownedStride,
484- int64_t *bbOff, int64_t *bbShape, void *ownedData,
485- int64_t *leftHaloShape, int64_t *leftHaloStride,
486- void *leftHaloData, int64_t *rightHaloShape,
487- int64_t *rightHaloStride, void *rightHaloData,
488- Transceiver *tc) {
/// Meta data needed by update_halo, bundled so it can be computed once and
/// then kept in a cache.
/// No copies allowed, only move-semantics and reference access.
/// All scalar members carry in-class initializers so that a
/// default-constructed object is in a well-defined (empty) state.
struct UHCache {
  // start and sizes of the local chunks destined for each remote;
  // used for packing when the send data must be copied into a buffer
  std::vector<int64_t> _lBufferStart, _lBufferSize, _rBufferStart, _rBufferSize;
  // sizes of the chunks received from each remote;
  // used for unpacking when the halo data must be copied out of a buffer
  std::vector<int64_t> _lRecvBufferSize, _rRecvBufferSize;
  // send maps (per-worker element counts and offsets)
  std::vector<int> _lSendSize, _rSendSize, _lSendOff, _rSendOff;
  // receive maps (per-worker element counts and offsets)
  std::vector<int> _lRecvSize, _rRecvSize, _lRecvOff, _rRecvOff;
  // copying needed? (send side, left-halo receive, right-halo receive)
  bool _bufferizeSend = false;
  bool _bufferizeLRecv = false;
  bool _bufferizeRRecv = false;
  // accumulated element counts over all workers
  int64_t _lTotalRecvSize = 0;
  int64_t _rTotalRecvSize = 0;
  int64_t _lTotalSendSize = 0;
  int64_t _rTotalSendSize = 0;

  UHCache() = default;
  UHCache(const UHCache &) = delete;
  UHCache(UHCache &&) = default;
  /// Construct from fully computed meta data; all vectors are moved in.
  UHCache(std::vector<int64_t> &&lBufferStart,
          std::vector<int64_t> &&lBufferSize,
          std::vector<int64_t> &&rBufferStart,
          std::vector<int64_t> &&rBufferSize,
          std::vector<int64_t> &&lRecvBufferSize,
          std::vector<int64_t> &&rRecvBufferSize, std::vector<int> &&lSendSize,
          std::vector<int> &&rSendSize, std::vector<int> &&lSendOff,
          std::vector<int> &&rSendOff, std::vector<int> &&lRecvSize,
          std::vector<int> &&rRecvSize, std::vector<int> &&lRecvOff,
          std::vector<int> &&rRecvOff, bool bufferizeSend, bool bufferizeLRecv,
          bool bufferizeRRecv, int64_t lTotalRecvSize, int64_t rTotalRecvSize,
          int64_t lTotalSendSize, int64_t rTotalSendSize)
      : _lBufferStart(std::move(lBufferStart)),
        _lBufferSize(std::move(lBufferSize)),
        _rBufferStart(std::move(rBufferStart)),
        _rBufferSize(std::move(rBufferSize)),
        _lRecvBufferSize(std::move(lRecvBufferSize)),
        _rRecvBufferSize(std::move(rRecvBufferSize)),
        _lSendSize(std::move(lSendSize)), _rSendSize(std::move(rSendSize)),
        _lSendOff(std::move(lSendOff)), _rSendOff(std::move(rSendOff)),
        _lRecvSize(std::move(lRecvSize)), _rRecvSize(std::move(rRecvSize)),
        _lRecvOff(std::move(lRecvOff)), _rRecvOff(std::move(rRecvOff)),
        _bufferizeSend(bufferizeSend), _bufferizeLRecv(bufferizeLRecv),
        _bufferizeRRecv(bufferizeRRecv), _lTotalRecvSize(lTotalRecvSize),
        _rTotalRecvSize(rTotalRecvSize), _lTotalSendSize(lTotalSendSize),
        _rTotalSendSize(rTotalSendSize) {}
  UHCache &operator=(const UHCache &) = delete;
  UHCache &operator=(UHCache &&) = default;
};
489525
490- #ifdef NO_TRANSCEIVER
491- initMPIRuntime ();
492- tc = getTransceiver ();
493- #endif
494- auto nworkers = tc->nranks ();
495- if (nworkers <= 1 || getenv (" DDPT_SKIP_COMM" ))
496- return nullptr ;
526+ UHCache getMetaData (rank_type nworkers, int64_t ndims, int64_t *ownedOff,
527+ int64_t *ownedShape, int64_t *ownedStride, int64_t *bbOff,
528+ int64_t *bbShape, int64_t *leftHaloShape,
529+ int64_t *leftHaloStride, int64_t *rightHaloShape,
530+ int64_t *rightHaloStride, Transceiver *tc) {
531+ UHCache cE; // holds data if non-cached
497532 auto myWorkerIndex = tc->rank ();
533+ cE._lTotalRecvSize = 0 ;
534+ cE._rTotalRecvSize = 0 ;
535+ cE._lTotalSendSize = 0 ;
536+ cE._rTotalSendSize = 0 ;
498537
499538 // Gather table with bounding box offsets and shapes for all workers
500539 // [ (w0 offsets) o_0, o_1, ..., o_ndims,
@@ -525,18 +564,19 @@ void *_idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
525564
526565 // find local elements to send to next workers (destination leftHalo)
527566 // and previous workers (destination rightHalo)
528- std::vector<int > lSendOff (nworkers, 0 ), rSendOff (nworkers, 0 );
529- std::vector<int > lSendSize (nworkers, 0 ), rSendSize (nworkers, 0 );
567+ cE._lSendOff .resize (nworkers, 0 );
568+ cE._rSendOff .resize (nworkers, 0 );
569+ cE._lSendSize .resize (nworkers, 0 );
570+ cE._rSendSize .resize (nworkers, 0 );
530571
531572 // use send buffer if owned data is strided or sending a subview
532- bool bufferizeSend = (!is_contiguous (ownedShape, ownedStride, ndims) ||
533- bbTotCols != ownedTotCols);
573+ cE. _bufferizeSend = (!is_contiguous (ownedShape, ownedStride, ndims) ||
574+ bbTotCols != ownedTotCols);
534575
535- std::vector<int64_t > lBufferStart (nworkers * ndims, 0 );
536- std::vector<int64_t > lBufferSize (nworkers * ndims, 0 );
537- std::vector<int64_t > rBufferStart (nworkers * ndims, 0 );
538- std::vector<int64_t > rBufferSize (nworkers * ndims, 0 );
539- int64_t lTotalSendSize = 0 , rTotalSendSize = 0 ;
576+ cE._lBufferStart .resize (nworkers * ndims, 0 );
577+ cE._lBufferSize .resize (nworkers * ndims, 0 );
578+ cE._rBufferStart .resize (nworkers * ndims, 0 );
579+ cE._rBufferSize .resize (nworkers * ndims, 0 );
540580
541581 for (auto i = 0 ; i < nworkers; ++i) {
542582 if (i == myWorkerIndex) {
@@ -559,127 +599,170 @@ void *_idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
559599
560600 if (i < myWorkerIndex) {
561601 // target is rightHalo
562- if (bufferizeSend ) {
563- rSendOff [i] = i ? rSendOff [i - 1 ] + rSendSize [i - 1 ] : 0 ;
564- rBufferStart [i * ndims] = localRowStart;
565- rBufferSize [i * ndims] = nRows;
602+ if (cE. _bufferizeSend ) {
603+ cE. _rSendOff [i] = i ? cE. _rSendOff [i - 1 ] + cE. _rSendSize [i - 1 ] : 0 ;
604+ cE. _rBufferStart [i * ndims] = localRowStart;
605+ cE. _rBufferSize [i * ndims] = nRows;
566606 for (auto j = 1 ; j < ndims; ++j) {
567- rBufferStart [i * ndims + j] = bbOff[j];
568- rBufferSize [i * ndims + j] = bbShape[j];
607+ cE. _rBufferStart [i * ndims + j] = bbOff[j];
608+ cE. _rBufferSize [i * ndims + j] = bbShape[j];
569609 }
570610 } else {
571- rSendOff [i] = localStart;
611+ cE. _rSendOff [i] = localStart;
572612 }
573- rSendSize [i] = nSend;
574- rTotalSendSize += nSend;
613+ cE. _rSendSize [i] = nSend;
614+ cE. _rTotalSendSize += nSend;
575615 } else {
576616 // target is leftHalo
577- if (bufferizeSend ) {
578- lSendOff [i] = i ? lSendOff [i - 1 ] + lSendSize [i - 1 ] : 0 ;
579- lBufferStart [i * ndims] = localRowStart;
580- lBufferSize [i * ndims] = nRows;
617+ if (cE. _bufferizeSend ) {
618+ cE. _lSendOff [i] = i ? cE. _lSendOff [i - 1 ] + cE. _lSendSize [i - 1 ] : 0 ;
619+ cE. _lBufferStart [i * ndims] = localRowStart;
620+ cE. _lBufferSize [i * ndims] = nRows;
581621 for (auto j = 1 ; j < ndims; ++j) {
582- lBufferStart [i * ndims + j] = bbOff[j];
583- lBufferSize [i * ndims + j] = bbShape[j];
622+ cE. _lBufferStart [i * ndims + j] = bbOff[j];
623+ cE. _lBufferSize [i * ndims + j] = bbShape[j];
584624 }
585625 } else {
586- lSendOff [i] = localStart;
626+ cE. _lSendOff [i] = localStart;
587627 }
588- lSendSize [i] = nSend;
589- lTotalSendSize += nSend;
628+ cE. _lSendSize [i] = nSend;
629+ cE. _lTotalSendSize += nSend;
590630 }
591631 }
592632 }
593633
594634 // receive maps
595- std::vector<int > lRecvSize (nworkers), rRecvSize (nworkers);
596- std::vector<int > lRecvOff (nworkers), rRecvOff (nworkers);
635+ cE._lRecvSize .resize (nworkers);
636+ cE._rRecvSize .resize (nworkers);
637+ cE._lRecvOff .resize (nworkers);
638+ cE._rRecvOff .resize (nworkers);
639+
597640 // receive size is sender's send size
598- tc->alltoall (lSendSize. data (), 1 , INT32, lRecvSize .data ());
599- tc->alltoall (rSendSize. data (), 1 , INT32, rRecvSize .data ());
641+ tc->alltoall (cE. _lSendSize . data (), 1 , INT32, cE. _lRecvSize .data ());
642+ tc->alltoall (cE. _rSendSize . data (), 1 , INT32, cE. _rRecvSize .data ());
600643 // compute offset in a contiguous receive buffer
601- lRecvOff [0 ] = 0 ;
602- rRecvOff [0 ] = 0 ;
644+ cE. _lRecvOff [0 ] = 0 ;
645+ cE. _rRecvOff [0 ] = 0 ;
603646 for (auto i = 1 ; i < nworkers; ++i) {
604- lRecvOff [i] = lRecvOff [i - 1 ] + lRecvSize [i - 1 ];
605- rRecvOff [i] = rRecvOff [i - 1 ] + rRecvSize [i - 1 ];
647+ cE. _lRecvOff [i] = cE. _lRecvOff [i - 1 ] + cE. _lRecvSize [i - 1 ];
648+ cE. _rRecvOff [i] = cE. _rRecvOff [i - 1 ] + cE. _rRecvSize [i - 1 ];
606649 }
607650
608651 // receive buffering
609- void *lRecvData, *rRecvData;
610- void *sendData;
611- bool bufferizeLRecv = !is_contiguous (leftHaloShape, leftHaloStride, ndims);
612- bool bufferizeRRecv = !is_contiguous (rightHaloShape, rightHaloStride, ndims);
613- std::vector<int64_t > lRecvBufferSize (nworkers * ndims, 0 );
614- std::vector<int64_t > rRecvBufferSize (nworkers * ndims, 0 );
652+ cE._bufferizeLRecv = !is_contiguous (leftHaloShape, leftHaloStride, ndims);
653+ cE._bufferizeRRecv = !is_contiguous (rightHaloShape, rightHaloStride, ndims);
654+ cE._lRecvBufferSize .resize (nworkers * ndims, 0 );
655+ cE._rRecvBufferSize .resize (nworkers * ndims, 0 );
615656
616657 // deduce receive shape for unpack
617- int64_t lTotalRecvSize = 0 , rTotalRecvSize = 0 ;
618658 for (auto i = 0 ; i < nworkers; ++i) {
619- if (bufferizeLRecv && lRecvSize [i] != 0 ) {
620- lTotalRecvSize += lRecvSize [i];
621- lRecvBufferSize [i * ndims] = lRecvSize [i] / bbTotCols; // nrows
659+ if (cE. _bufferizeLRecv && cE. _lRecvSize [i] != 0 ) {
660+ cE. _lTotalRecvSize += cE. _lRecvSize [i];
661+ cE. _lRecvBufferSize [i * ndims] = cE. _lRecvSize [i] / bbTotCols; // nrows
622662 for (auto j = 1 ; j < ndims; ++j) {
623- lRecvBufferSize [i * ndims + j] = bbShape[j]; // leftHaloShape[j]
663+ cE. _lRecvBufferSize [i * ndims + j] = bbShape[j]; // leftHaloShape[j]
624664 }
625665 }
626- if (bufferizeRRecv && rRecvSize [i] != 0 ) {
627- rTotalRecvSize += rRecvSize [i];
628- rRecvBufferSize [i * ndims] = rRecvSize [i] / bbTotCols; // nrows
666+ if (cE. _bufferizeRRecv && cE. _rRecvSize [i] != 0 ) {
667+ cE. _rTotalRecvSize += cE. _rRecvSize [i];
668+ cE. _rRecvBufferSize [i * ndims] = cE. _rRecvSize [i] / bbTotCols; // nrows
629669 for (auto j = 1 ; j < ndims; ++j) {
630- rRecvBufferSize [i * ndims + j] = bbShape[j]; // rightHaloShape[j]
670+ cE. _rRecvBufferSize [i * ndims + j] = bbShape[j]; // rightHaloShape[j]
631671 }
632672 }
633673 }
674+ return cE;
675+ };
634676
635- Buffer recvBuff;
636- Buffer sendBuff;
637- if (bufferizeLRecv || bufferizeLRecv) {
638- recvBuff.resize (std::max (lTotalRecvSize, rTotalRecvSize) *
677+ // / @brief Update data in halo parts
678+ // / We assume tensor is partitioned along the first dimension only
679+ // / (row partitioning) and partitions are ordered by ranks
680+ // / if cache-key is provided (>=0) meta data is read from cache
681+ // / @return (MPI) handles
682+ void *_idtr_update_halo (DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
683+ int64_t *ownedShape, int64_t *ownedStride,
684+ int64_t *bbOff, int64_t *bbShape, void *ownedData,
685+ int64_t *leftHaloShape, int64_t *leftHaloStride,
686+ void *leftHaloData, int64_t *rightHaloShape,
687+ int64_t *rightHaloStride, void *rightHaloData,
688+ Transceiver *tc, int64_t key) {
689+
690+ #ifdef NO_TRANSCEIVER
691+ initMPIRuntime ();
692+ tc = getTransceiver ();
693+ #endif
694+ auto nworkers = tc->nranks ();
695+ if (nworkers <= 1 || getenv (" DDPT_SKIP_COMM" ))
696+ return nullptr ;
697+
698+ // not thread-safe
699+ static std::unordered_map<int64_t , UHCache> uhCache; // meta-data cache
700+ static UHCache *cache = nullptr ; // reading either from non-cached or cached
701+
702+ auto cIt = key == -1 ? uhCache.end () : uhCache.find (key);
703+ if (cIt == uhCache.end ()) { // not in cache
704+ // update cache if requested
705+ cIt = uhCache
706+ .insert_or_assign (
707+ key, std::move (getMetaData (
708+ nworkers, ndims, ownedOff, ownedShape, ownedStride,
709+ bbOff, bbShape, leftHaloShape, leftHaloStride,
710+ rightHaloShape, rightHaloStride, tc)))
711+ .first ;
712+ }
713+ cache = &(cIt->second );
714+
715+ Buffer recvBuff (0 ), sendBuff (0 );
716+ if (cache->_bufferizeLRecv || cache->_bufferizeLRecv ) {
717+ recvBuff.resize (std::max (cache->_lTotalRecvSize , cache->_rTotalRecvSize ) *
639718 sizeof_dtype (ddpttype));
640719 }
641- if (bufferizeSend ) {
642- sendBuff.resize (std::max (lTotalSendSize, rTotalSendSize ) *
720+ if (cache-> _bufferizeSend ) {
721+ sendBuff.resize (std::max (cache-> _lTotalSendSize , cache-> _rTotalSendSize ) *
643722 sizeof_dtype (ddpttype));
644723 }
645- lRecvData = bufferizeLRecv ? recvBuff.data () : leftHaloData;
646- rRecvData = bufferizeRRecv ? recvBuff.data () : rightHaloData;
647- sendData = bufferizeSend ? sendBuff.data () : ownedData;
724+
725+ void *lRecvData = cache->_bufferizeLRecv ? recvBuff.data () : leftHaloData;
726+ void *rRecvData = cache->_bufferizeRRecv ? recvBuff.data () : rightHaloData;
727+ void *sendData = cache->_bufferizeSend ? sendBuff.data () : ownedData;
648728
649729 // communicate left/right halos
650- if (bufferizeSend) {
651- bufferize (ownedData, ddpttype, ownedShape, ownedStride, lBufferStart.data (),
652- lBufferSize.data (), ndims, nworkers, sendBuff.data ());
730+ if (cache->_bufferizeSend ) {
731+ bufferize (ownedData, ddpttype, ownedShape, ownedStride,
732+ cache->_lBufferStart .data (), cache->_lBufferSize .data (), ndims,
733+ nworkers, sendBuff.data ());
653734 }
654- auto lwh = tc->alltoall (sendData, lSendSize.data (), lSendOff.data (), ddpttype,
655- lRecvData, lRecvSize.data (), lRecvOff.data ());
656- if (bufferizeSend) {
657- bufferize (ownedData, ddpttype, ownedShape, ownedStride, rBufferStart.data (),
658- rBufferSize.data (), ndims, nworkers, sendBuff.data ());
735+ auto lwh = tc->alltoall (sendData, cache->_lSendSize .data (),
736+ cache->_lSendOff .data (), ddpttype, lRecvData,
737+ cache->_lRecvSize .data (), cache->_lRecvOff .data ());
738+ if (cache->_bufferizeSend ) {
739+ bufferize (ownedData, ddpttype, ownedShape, ownedStride,
740+ cache->_rBufferStart .data (), cache->_rBufferSize .data (), ndims,
741+ nworkers, sendBuff.data ());
659742 }
660- auto rwh = tc->alltoall (sendData, rSendSize.data (), rSendOff.data (), ddpttype,
661- rRecvData, rRecvSize.data (), rRecvOff.data ());
743+ auto rwh = tc->alltoall (sendData, cache->_rSendSize .data (),
744+ cache->_rSendOff .data (), ddpttype, rRecvData,
745+ cache->_rRecvSize .data (), cache->_rRecvOff .data ());
662746
663- auto wait = [=, _lRecvBufferSize = std::move (lRecvBufferSize),
664- _rRecvBufferSize = std::move (rRecvBufferSize)]() {
747+ auto wait = [=]() {
665748 tc->wait (lwh);
666749 std::vector<int64_t > recvBufferStart (nworkers * ndims, 0 );
667- if (bufferizeLRecv ) {
750+ if (cache-> _bufferizeLRecv ) {
668751 unpack (lRecvData, ddpttype, leftHaloShape, leftHaloStride,
669- recvBufferStart.data (), _lRecvBufferSize.data (), ndims, nworkers ,
670- leftHaloData);
752+ recvBufferStart.data (), cache-> _lRecvBufferSize .data (), ndims,
753+ nworkers, leftHaloData);
671754 }
672755 tc->wait (rwh);
673- if (bufferizeRRecv ) {
756+ if (cache-> _bufferizeRRecv ) {
674757 unpack (rRecvData, ddpttype, rightHaloShape, rightHaloStride,
675- recvBufferStart.data (), _rRecvBufferSize.data (), ndims, nworkers ,
676- rightHaloData);
758+ recvBufferStart.data (), cache-> _rRecvBufferSize .data (), ndims,
759+ nworkers, rightHaloData);
677760 }
678761 };
679- assert (lRecvBufferSize.empty () && rRecvBufferSize.empty ());
680762
681763 // FIXME (in imex) buffer-dealloc pass deallocs halo strides and sizes
682- if (bufferizeLRecv || bufferizeRRecv || getenv (" DDPT_NO_ASYNC" )) {
764+ if (cache->_bufferizeLRecv || cache->_bufferizeRRecv ||
765+ getenv (" DDPT_NO_ASYNC" )) {
683766 wait ();
684767 return nullptr ;
685768 }
@@ -694,7 +777,7 @@ void *_idtr_update_halo(Transceiver *tc, int64_t gShapeRank, void *gShapeDescr,
694777 void *oDataDescr, int64_t bbOffRank, void *bbOffDescr,
695778 int64_t bbShapeRank, void *bbShapeDescr,
696779 int64_t lHaloRank, void *lHaloDescr, int64_t rHaloRank,
697- void *rHaloDescr) {
780+ void *rHaloDescr, int64_t key ) {
698781
699782 auto ddpttype = DTYPE<T>::value;
700783
@@ -710,7 +793,7 @@ void *_idtr_update_halo(Transceiver *tc, int64_t gShapeRank, void *gShapeDescr,
710793 ddpttype, ownedData.rank (), ownedOff.data (), ownedData.sizes (),
711794 ownedData.strides (), bbOff.data (), bbShape.data (), ownedData.data (),
712795 leftHalo.sizes (), leftHalo.strides (), leftHalo.data (), rightHalo.sizes (),
713- rightHalo.strides (), rightHalo.data (), tc);
796+ rightHalo.strides (), rightHalo.data (), tc, key );
714797}
715798
716799extern " C" {
@@ -720,11 +803,11 @@ extern "C" {
720803 int64_t oOffRank, void *oOffDescr, int64_t oDataRank, void *oDataDescr, \
721804 int64_t bbOffRank, void *bbOffDescr, int64_t bbShapeRank, \
722805 void *bbShapeDescr, int64_t lHaloRank, void *lHaloDescr, \
723- int64_t rHaloRank, void *rHaloDescr) { \
806+ int64_t rHaloRank, void *rHaloDescr, int64_t key) { \
724807 return _idtr_update_halo<_typ>( \
725808 tc, gShapeRank , gShapeDescr , oOffRank, oOffDescr, oDataRank, \
726809 oDataDescr, bbOffRank, bbOffDescr, bbShapeRank, bbShapeDescr, \
727- lHaloRank, lHaloDescr, rHaloRank, rHaloDescr); \
810+ lHaloRank, lHaloDescr, rHaloRank, rHaloDescr, key); \
728811 }
729812
730813TYPED_UPDATE_HALO (f64 , double );
0 commit comments