This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 5c971f9

Overlap comm and compute (#48)
* adding overlap-comm-and-compute pass
* implementing async comm support
1 parent 619aa60 commit 5c971f9
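The net effect of this commit is that communication can be overlapped with independent computation: the strided alltoall now returns a Transceiver::WaitHandle immediately, and completion is deferred to an explicit wait(). Below is a minimal sketch of the intended calling pattern, assuming the Transceiver interface from src/include/ddptensor/Transceiver.hpp; the buffers, counts, and the do_local_compute() helper are placeholders for illustration and are not part of this commit.

#include "include/ddptensor/Transceiver.hpp"

void do_local_compute(); // placeholder for computation independent of recvBuf

// Hypothetical illustration of the overlap pattern enabled by this commit.
void exchange_and_compute(Transceiver *tc, const void *sendBuf, void *recvBuf,
                          const int *sendCounts, const int *sendOffs,
                          const int *recvCounts, const int *recvOffs,
                          DTypeId dtype) {
  // 1. start the non-blocking all-to-all; returns immediately with a handle
  auto handle = tc->alltoall(sendBuf, sendCounts, sendOffs, dtype,
                             recvBuf, recvCounts, recvOffs);
  // 2. do work that does not touch recvBuf while the exchange is in flight
  do_local_compute();
  // 3. block until the exchange has completed before reading recvBuf
  tc->wait(handle);
}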

7 files changed, +116 −60 lines


imex_version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-e5d49d469ec53f62d5befc7c30f4f404a7870031
+ff9d13d5e71213a99a9a436895072b853236c040

src/Deferred.cpp

Lines changed: 0 additions & 4 deletions
@@ -13,8 +13,6 @@
 #include "include/ddptensor/Service.hpp"
 #include "include/ddptensor/Transceiver.hpp"

-#include <imex/Dialect/Dist/IR/DistOps.h>
-#include <imex/Dialect/DistRuntime/IR/DistRuntimeOps.h>
 #include <imex/Dialect/PTensor/IR/PTensorOps.h>
 #include <mlir/Dialect/Func/IR/FuncOps.h>
 #include <mlir/Dialect/LLVMIR/LLVMDialect.h>
@@ -95,8 +93,6 @@ void process_promises() {

   // Create a MLIR module
   auto module = builder.create<::mlir::ModuleOp>(loc);
-  auto protos = builder.create<::imex::distruntime::RuntimePrototypesOp>(loc);
-  module.push_back(protos);
   // Create the jit func
   // create dummy type, we'll replace it with the actual type later
   auto dummyFuncType = builder.getFunctionType({}, {});

src/MPITransceiver.cpp

Lines changed: 18 additions & 7 deletions
@@ -187,13 +187,17 @@ void MPITransceiver::reduce_all(void *inout, DTypeId T, size_t N,
   MPI_Allreduce(MPI_IN_PLACE, inout, N, to_mpi(T), to_mpi(op), _comm);
 }

-void MPITransceiver::alltoall(const void *buffer_send, const int *counts_send,
-                              const int *displacements_send, DTypeId datatype,
-                              void *buffer_recv, const int *counts_recv,
-                              const int *displacements_recv) {
-  MPI_Alltoallv(buffer_send, counts_send, displacements_send, to_mpi(datatype),
-                buffer_recv, counts_recv, displacements_recv, to_mpi(datatype),
-                _comm);
+Transceiver::WaitHandle
+MPITransceiver::alltoall(const void *buffer_send, const int *counts_send,
+                         const int *displacements_send, DTypeId datatype,
+                         void *buffer_recv, const int *counts_recv,
+                         const int *displacements_recv) {
+  MPI_Request request;
+  MPI_Ialltoallv(buffer_send, counts_send, displacements_send, to_mpi(datatype),
+                 buffer_recv, counts_recv, displacements_recv, to_mpi(datatype),
+                 _comm, &request);
+  static_assert(sizeof(request) == sizeof(WaitHandle));
+  return static_cast<WaitHandle>(request);
 }

 void MPITransceiver::alltoall(const void *buffer_send, const int counts,
@@ -226,3 +230,10 @@ void MPITransceiver::send_recv(void *buffer_send, int count_send,
   MPI_Sendrecv_replace(buffer_send, count_send, to_mpi(datatype_send), dest,
                        SRTAG, source, SRTAG, _comm, MPI_STATUS_IGNORE);
 }
+
+void MPITransceiver::wait(WaitHandle h) {
+  if (h) {
+    auto r = static_cast<MPI_Request>(h);
+    MPI_Wait(&r, MPI_STATUS_IGNORE);
+  }
+}
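For reference, the underlying MPI pattern is the standard non-blocking collective: MPI_Ialltoallv starts the exchange and hands back an MPI_Request, and MPI_Wait later completes it. Note that round-tripping the request through the integer WaitHandle via static_cast assumes an MPI implementation whose MPI_Request is an integral type (as in MPICH-derived MPIs). A self-contained sketch of that pattern follows; the buffer sizes and one-element-per-rank layout are illustrative only.

// Standalone sketch: MPI_Ialltoallv starts the exchange, unrelated work
// proceeds while it is in flight, MPI_Wait completes it.
#include <mpi.h>
#include <vector>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int nranks;
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);

  std::vector<double> sendbuf(nranks, 1.0), recvbuf(nranks, 0.0);
  std::vector<int> counts(nranks, 1), displs(nranks);
  for (int i = 0; i < nranks; ++i) displs[i] = i;

  MPI_Request request;
  MPI_Ialltoallv(sendbuf.data(), counts.data(), displs.data(), MPI_DOUBLE,
                 recvbuf.data(), counts.data(), displs.data(), MPI_DOUBLE,
                 MPI_COMM_WORLD, &request);

  // ... independent computation can run here, overlapping the exchange ...

  MPI_Wait(&request, MPI_STATUS_IGNORE);
  MPI_Finalize();
  return 0;
}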

src/idtr.cpp

Lines changed: 82 additions & 40 deletions
@@ -30,7 +30,33 @@ template <typename T> T *mr_to_ptr(void *ptr, intptr_t offset) {
   return reinterpret_cast<T *>(ptr) + offset; // &mr.aligned[mr.offset]
 }

+// abstract handle providing an abstract wait method
+struct WaitHandleBase {
+  virtual ~WaitHandleBase(){};
+  virtual void wait() = 0;
+};
+
+// concrete handle to be instantiated with a lambda or alike
+// the lambda will be executed within wait()
+template <typename T> class WaitHandle : public WaitHandleBase {
+  T _fini;
+
+public:
+  WaitHandle(T fini) : _fini(fini) {}
+  virtual void wait() override { _fini(); }
+};
+
+template <typename T> WaitHandle<T> *mkWaitHandle(T fini) {
+  return new WaitHandle<T>(fini);
+};
+
 extern "C" {
+void _idtr_wait(WaitHandleBase *handle) {
+  if (handle) {
+    handle->wait();
+    delete handle;
+  }
+}

 #define NO_TRANSCEIVER
 #ifdef NO_TRANSCEIVER
@@ -401,8 +427,9 @@ void _idtr_reshape(DTypeId ddpttype, int64_t lRank, int64_t *gShapePtr,
   Buffer outbuff(totSSz * sizeof_dtype(ddpttype), 2); // FIXME debug value
   bufferizeN(lDataPtr, ddpttype, lShapePtr, lStridesPtr, lsOffs.data(),
              lsEnds.data(), lRank, N, outbuff.data());
-  tc->alltoall(outbuff.data(), sszs.data(), soffs.data(), ddpttype, oDataPtr,
-               rszs.data(), roffs.data());
+  auto hdl = tc->alltoall(outbuff.data(), sszs.data(), soffs.data(), ddpttype,
+                          oDataPtr, rszs.data(), roffs.data());
+  tc->wait(hdl);
 }

 /// @brief reshape tensor
@@ -451,21 +478,22 @@ TYPED_RESHAPE(i1, bool);
 /// @brief Update data in halo parts
 /// We assume tensor is partitioned along the first dimension only
 /// (row partitioning) and partitions are ordered by ranks
-void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
-                       int64_t *ownedShape, int64_t *ownedStride,
-                       int64_t *bbOff, int64_t *bbShape, void *ownedData,
-                       int64_t *leftHaloShape, int64_t *leftHaloStride,
-                       void *leftHaloData, int64_t *rightHaloShape,
-                       int64_t *rightHaloStride, void *rightHaloData,
-                       Transceiver *tc) {
+/// @return (MPI) handles
+void *_idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
+                        int64_t *ownedShape, int64_t *ownedStride,
+                        int64_t *bbOff, int64_t *bbShape, void *ownedData,
+                        int64_t *leftHaloShape, int64_t *leftHaloStride,
+                        void *leftHaloData, int64_t *rightHaloShape,
+                        int64_t *rightHaloStride, void *rightHaloData,
+                        Transceiver *tc) {

 #ifdef NO_TRANSCEIVER
   initMPIRuntime();
   tc = getTransceiver();
 #endif
   auto nworkers = tc->nranks();
   if (nworkers <= 1 || getenv("DDPT_SKIP_COMM"))
-    return;
+    return nullptr;
   auto myWorkerIndex = tc->rank();

   // Gather table with bounding box offsets and shapes for all workers
@@ -582,7 +610,6 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
   void *sendData;
   bool bufferizeLRecv = !is_contiguous(leftHaloShape, leftHaloStride, ndims);
   bool bufferizeRRecv = !is_contiguous(rightHaloShape, rightHaloStride, ndims);
-  std::vector<int64_t> recvBufferStart(nworkers * ndims, 0);
   std::vector<int64_t> lRecvBufferSize(nworkers * ndims, 0);
   std::vector<int64_t> rRecvBufferSize(nworkers * ndims, 0);

@@ -624,35 +651,50 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
     bufferize(ownedData, ddpttype, ownedShape, ownedStride, lBufferStart.data(),
               lBufferSize.data(), ndims, nworkers, sendBuff.data());
   }
-  tc->alltoall(sendData, lSendSize.data(), lSendOff.data(), ddpttype, lRecvData,
-               lRecvSize.data(), lRecvOff.data());
-  if (bufferizeLRecv) {
-    unpack(lRecvData, ddpttype, leftHaloShape, leftHaloStride,
-           recvBufferStart.data(), lRecvBufferSize.data(), ndims, nworkers,
-           leftHaloData);
-  }
+  auto lwh = tc->alltoall(sendData, lSendSize.data(), lSendOff.data(), ddpttype,
+                          lRecvData, lRecvSize.data(), lRecvOff.data());
   if (bufferizeSend) {
     bufferize(ownedData, ddpttype, ownedShape, ownedStride, rBufferStart.data(),
               rBufferSize.data(), ndims, nworkers, sendBuff.data());
   }
-  tc->alltoall(sendData, rSendSize.data(), rSendOff.data(), ddpttype, rRecvData,
-               rRecvSize.data(), rRecvOff.data());
-  if (bufferizeRRecv) {
-    unpack(rRecvData, ddpttype, rightHaloShape, rightHaloStride,
-           recvBufferStart.data(), rRecvBufferSize.data(), ndims, nworkers,
-           rightHaloData);
+  auto rwh = tc->alltoall(sendData, rSendSize.data(), rSendOff.data(), ddpttype,
+                          rRecvData, rRecvSize.data(), rRecvOff.data());
+
+  auto wait = [=, _lRecvBufferSize = std::move(lRecvBufferSize),
+               _rRecvBufferSize = std::move(rRecvBufferSize)]() {
+    tc->wait(lwh);
+    std::vector<int64_t> recvBufferStart(nworkers * ndims, 0);
+    if (bufferizeLRecv) {
+      unpack(lRecvData, ddpttype, leftHaloShape, leftHaloStride,
+             recvBufferStart.data(), _lRecvBufferSize.data(), ndims, nworkers,
+             leftHaloData);
+    }
+    tc->wait(rwh);
+    if (bufferizeRRecv) {
+      unpack(rRecvData, ddpttype, rightHaloShape, rightHaloStride,
+             recvBufferStart.data(), _rRecvBufferSize.data(), ndims, nworkers,
+             rightHaloData);
+    }
+  };
+  assert(lRecvBufferSize.empty() && rRecvBufferSize.empty());
+
+  // FIXME (in imex) buffer-dealloc pass deallocs halo strides and sizes
+  if (bufferizeLRecv || bufferizeRRecv || getenv("DDPT_NO_ASYNC")) {
+    wait();
+    return nullptr;
   }
+  return mkWaitHandle(wait);
 }

 /// @brief templated wrapper for typed function versions calling
 /// _idtr_update_halo
 template <typename T>
-void _idtr_update_halo(Transceiver *tc, int64_t gShapeRank, void *gShapeDescr,
-                       int64_t oOffRank, void *oOffDescr, int64_t oDataRank,
-                       void *oDataDescr, int64_t bbOffRank, void *bbOffDescr,
-                       int64_t bbShapeRank, void *bbShapeDescr,
-                       int64_t lHaloRank, void *lHaloDescr, int64_t rHaloRank,
-                       void *rHaloDescr) {
+void *_idtr_update_halo(Transceiver *tc, int64_t gShapeRank, void *gShapeDescr,
+                        int64_t oOffRank, void *oOffDescr, int64_t oDataRank,
+                        void *oDataDescr, int64_t bbOffRank, void *bbOffDescr,
+                        int64_t bbShapeRank, void *bbShapeDescr,
+                        int64_t lHaloRank, void *lHaloDescr, int64_t rHaloRank,
+                        void *rHaloDescr) {

   auto ddpttype = DTYPE<T>::value;

@@ -664,25 +706,25 @@ void _idtr_update_halo(Transceiver *tc, int64_t gShapeRank, void *gShapeDescr,
   UnrankedMemRefType<T> leftHalo(lHaloRank, lHaloDescr);
   UnrankedMemRefType<T> rightHalo(rHaloRank, rHaloDescr);

-  _idtr_update_halo(ddpttype, ownedData.rank(), ownedOff.data(),
-                    ownedData.sizes(), ownedData.strides(), bbOff.data(),
-                    bbShape.data(), ownedData.data(), leftHalo.sizes(),
-                    leftHalo.strides(), leftHalo.data(), rightHalo.sizes(),
-                    rightHalo.strides(), rightHalo.data(), tc);
+  return _idtr_update_halo(
+      ddpttype, ownedData.rank(), ownedOff.data(), ownedData.sizes(),
+      ownedData.strides(), bbOff.data(), bbShape.data(), ownedData.data(),
+      leftHalo.sizes(), leftHalo.strides(), leftHalo.data(), rightHalo.sizes(),
+      rightHalo.strides(), rightHalo.data(), tc);
 }

 extern "C" {
 #define TYPED_UPDATE_HALO(_sfx, _typ) \
-  void _idtr_update_halo_##_sfx( \
+  void *_idtr_update_halo_##_sfx( \
       Transceiver *tc, int64_t gShapeRank, void *gShapeDescr, \
       int64_t oOffRank, void *oOffDescr, int64_t oDataRank, void *oDataDescr, \
       int64_t bbOffRank, void *bbOffDescr, int64_t bbShapeRank, \
       void *bbShapeDescr, int64_t lHaloRank, void *lHaloDescr, \
       int64_t rHaloRank, void *rHaloDescr) { \
-    _idtr_update_halo<_typ>(tc, gShapeRank, gShapeDescr, oOffRank, oOffDescr, \
-                            oDataRank, oDataDescr, bbOffRank, bbOffDescr, \
-                            bbShapeRank, bbShapeDescr, lHaloRank, lHaloDescr, \
-                            rHaloRank, rHaloDescr); \
+    return _idtr_update_halo<_typ>( \
+        tc, gShapeRank, gShapeDescr, oOffRank, oOffDescr, oDataRank, \
+        oDataDescr, bbOffRank, bbOffDescr, bbShapeRank, bbShapeDescr, \
+        lHaloRank, lHaloDescr, rHaloRank, rHaloDescr); \
   }

 TYPED_UPDATE_HALO(f64, double);
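The WaitHandleBase / WaitHandle&lt;T&gt; / mkWaitHandle trio above defers the completion work (waiting on both all-to-all handles and unpacking into the halos) into a lambda that the generated code later triggers through _idtr_wait. Below is a compressed, self-contained sketch of that mechanism; main() and the puts() payload are illustrative stand-ins for the real halo-update lambda.

// Self-contained illustration of the deferred-completion pattern from
// idtr.cpp: the "finish" work is captured in a lambda, wrapped in a
// heap-allocated handle, and executed later via _idtr_wait(), which also
// frees the handle.
#include <cstdio>

struct WaitHandleBase {
  virtual ~WaitHandleBase() {}
  virtual void wait() = 0;
};

template <typename T> class WaitHandle : public WaitHandleBase {
  T _fini;

public:
  WaitHandle(T fini) : _fini(fini) {}
  void wait() override { _fini(); }
};

template <typename T> WaitHandle<T> *mkWaitHandle(T fini) {
  return new WaitHandle<T>(fini);
}

void _idtr_wait(WaitHandleBase *handle) {
  if (handle) {
    handle->wait();
    delete handle;
  }
}

int main() {
  // stands in for the halo-update lambda that waits on MPI and unpacks
  auto *h = mkWaitHandle([]() { std::puts("communication finished"); });
  // ... compute on data that does not depend on the halos ...
  _idtr_wait(h); // later: block, finalize, release the handle
  return 0;
}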

src/include/ddptensor/MPITransceiver.hpp

Lines changed: 6 additions & 4 deletions
@@ -27,16 +27,18 @@ class MPITransceiver : public Transceiver {
   virtual void barrier();
   virtual void bcast(void *ptr, size_t N, rank_type root);
   virtual void reduce_all(void *inout, DTypeId T, size_t N, RedOpType op);
-  virtual void alltoall(const void *buffer_send, const int *counts_send,
-                        const int *displacements_send, DTypeId datatype_send,
-                        void *buffer_recv, const int *counts_recv,
-                        const int *displacements_recv);
+  virtual WaitHandle alltoall(const void *buffer_send, const int *counts_send,
+                              const int *displacements_send,
+                              DTypeId datatype_send, void *buffer_recv,
+                              const int *counts_recv,
+                              const int *displacements_recv);
   virtual void alltoall(const void *buffer_send, const int counts,
                         DTypeId datatype, void *buffer_recv);
   virtual void gather(void *buffer, const int *counts, const int *displacements,
                       DTypeId datatype, rank_type root);
   virtual void send_recv(void *buffer_send, int count_send,
                          DTypeId datatype_send, int dest, int source);
+  virtual void wait(WaitHandle);

 private:
   rank_type _nranks, _rank;
src/include/ddptensor/Transceiver.hpp

Lines changed: 8 additions & 4 deletions
@@ -10,6 +10,8 @@

 class Transceiver {
 public:
+  using WaitHandle = uint32_t;
+
   virtual ~Transceiver(){};

   virtual bool is_cw() = 0;
@@ -38,10 +40,11 @@ class Transceiver {
   virtual void reduce_all(void *inout, DTypeId T, size_t N, RedOpType op) = 0;

   // umm, can this be higher-level?
-  virtual void alltoall(const void *buffer_send, const int *counts_send,
-                        const int *displacements_send, DTypeId datatype_send,
-                        void *buffer_recv, const int *counts_recv,
-                        const int *displacements_recv) = 0;
+  virtual WaitHandle alltoall(const void *buffer_send, const int *counts_send,
+                              const int *displacements_send,
+                              DTypeId datatype_send, void *buffer_recv,
+                              const int *counts_recv,
+                              const int *displacements_recv) = 0;
   virtual void alltoall(const void *buffer_send, const int counts,
                         DTypeId datatype, void *buffer_recv) = 0;

@@ -50,6 +53,7 @@ class Transceiver {

   virtual void send_recv(void *buffer_send, int count_send,
                          DTypeId datatype_send, int dest, int source) = 0;
+  virtual void wait(WaitHandle) = 0;
 };

 extern void init_transceiver(Transceiver *);

src/jit/mlir.cpp

Lines changed: 1 addition & 0 deletions
@@ -435,6 +435,7 @@ static const char *pass_pipeline =
     "func.func(dist-infer-elementwise-cores),"
     "convert-dist-to-standard,"
     "canonicalize,"
+    "overlap-comm-and-compute,"
     "lower-distruntime-to-idtr,"
     "convert-ptensor-to-linalg,"
     "canonicalize,"
