Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 2a940b1

Browse files
fschlimb and tkarna authored
Distcoalesce2 (#42)
* enable distcoalesce
* fix update_halo for general subview case
* updating imex sha

---------

Co-authored-by: Tuomas Karna <tuomas.karna@intel.com>
1 parent c43c05f commit 2a940b1

File tree

5 files changed

+44
-16
lines changed

5 files changed

+44
-16
lines changed

imex_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
89b5d56c4774ddb82ab8f896c3d977c6edae267b
1+
571f54577e2301c70033fef9a05b8a96fa841d2b

src/idtr.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -464,16 +464,21 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
464464
auto ownedRows = ownedShape[0];
465465
auto ownedRowEnd = ownedRowStart + ownedRows;
466466
// all remaining dims are treated as one large column
467-
auto ownedCols = std::accumulate(&ownedShape[1], &ownedShape[ndims], 1,
467+
auto ownedTotCols = std::accumulate(&ownedShape[1], &ownedShape[ndims], 1,
468+
std::multiplies<int64_t>());
469+
auto bbTotCols = std::accumulate(&bbShape[1], &bbShape[ndims], 1,
468470
std::multiplies<int64_t>());
469471

470472
// find local elements to send to next workers (destination leftHalo)
471473
// and previous workers (destination rightHalo)
472474
std::vector<int> lSendOff(nworkers, 0), rSendOff(nworkers, 0);
473475
std::vector<int> lSendSize(nworkers, 0), rSendSize(nworkers, 0);
474476

475-
// use send buffer if owned data is strided
476-
bool bufferizeSend = !is_contiguous(ownedShape, ownedStride, ndims);
477+
// use send buffer if owned data is strided or sending a subview
478+
bool bufferizeSend = (!is_contiguous(ownedShape, ownedStride, ndims) ||
479+
bbTotCols != ownedTotCols);
480+
481+
// assert(!bufferizeSend);
477482
std::vector<int64_t> lBufferStart(nworkers * ndims, 0);
478483
std::vector<int64_t> lBufferSize(nworkers * ndims, 0);
479484
std::vector<int64_t> rBufferStart(nworkers * ndims, 0);
@@ -495,9 +500,9 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
495500
auto globalRowStart = std::max(ownedRowStart, bRowStart);
496501
auto globalRowEnd = std::min(ownedRowEnd, bRowEnd);
497502
auto localRowStart = globalRowStart - ownedRowStart;
498-
auto localStart = (int)(localRowStart)*ownedCols;
503+
auto localStart = (int)(localRowStart)*ownedTotCols;
499504
auto nRows = globalRowEnd - globalRowStart;
500-
auto nSend = (int)(nRows)*ownedCols;
505+
auto nSend = (int)(nRows)*bbTotCols;
501506

502507
if (i < myWorkerIndex) {
503508
// target is rightHalo
@@ -506,8 +511,8 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
506511
rBufferStart[i * ndims] = localRowStart;
507512
rBufferSize[i * ndims] = nRows;
508513
for (auto j = 1; j < ndims; ++j) {
509-
rBufferStart[i * ndims + j] = ownedOff[j];
510-
rBufferSize[i * ndims + j] = ownedShape[j];
514+
rBufferStart[i * ndims + j] = bbOff[j];
515+
rBufferSize[i * ndims + j] = bbShape[j];
511516
}
512517
} else {
513518
rSendOff[i] = localStart;
@@ -521,8 +526,8 @@ void _idtr_update_halo(DTypeId ddpttype, int64_t ndims, int64_t *ownedOff,
521526
lBufferStart[i * ndims] = localRowStart;
522527
lBufferSize[i * ndims] = nRows;
523528
for (auto j = 1; j < ndims; ++j) {
524-
lBufferStart[i * ndims + j] = ownedOff[j];
525-
lBufferSize[i * ndims + j] = ownedShape[j];
529+
lBufferStart[i * ndims + j] = bbOff[j];
530+
lBufferSize[i * ndims + j] = bbShape[j];
526531
}
527532
} else {
528533
lSendOff[i] = localStart;

src/jit/mlir.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ JIT::createExecutionEngine(::mlir::ModuleOp &module) {
431431
static const char *pass_pipeline =
432432
getenv("DDPT_PASSES") ? getenv("DDPT_PASSES")
433433
: "func.func(ptensor-dist),"
434-
// "func.func(dist-coalesce)," FIXME
434+
"func.func(dist-coalesce),"
435435
"convert-dist-to-standard,"
436436
"convert-ptensor-to-linalg,"
437437
"canonicalize,"
@@ -452,7 +452,7 @@ static const char *pass_pipeline =
452452
"func.func(linalg-detensorize),"
453453
"func.func(tensor-bufferize),"
454454
"func.func(finalizing-bufferize),"
455-
// "func.func(buffer-deallocation)," FIXME
455+
"func.func(buffer-deallocation),"
456456
// "imex-remove-temporaries," FIXME
457457
"func.func(convert-linalg-to-parallel-loops),"
458458
"func.func(scf-parallel-loop-fusion),"
@@ -491,9 +491,10 @@ JIT::JIT()
491491
}
492492
// some verbosity
493493
if (_verbose) {
494-
std::cerr << "pass pipeline: " << pass_pipeline << std::endl;
494+
std::cerr << "DDPT_PASSES=\"" << pass_pipeline << "\"" << std::endl;
495495
// _pm.enableStatistics();
496-
_pm.enableTiming();
496+
if (_verbose > 2)
497+
_pm.enableTiming();
497498
// if(_verbose > 1)
498499
// _pm.dump();
499500
if (_verbose > 3)

test/test_setget.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def doit(aapi):
2626
def test_setitem2(self):
2727
def doit(aapi):
2828
a = aapi.ones((16, 16), aapi.float64)
29-
b = aapi.zeros((16, 16), aapi.float64)
29+
b = aapi.fromfunction(lambda i, j: 10 * i + j, (16, 16), dtype=aapi.float64)
3030
a[1:8, 0:6] = b[0:7, 0:6]
3131
return a
3232

@@ -35,12 +35,31 @@ def doit(aapi):
3535
def test_setitem3(self):
3636
def doit(aapi):
3737
a = aapi.ones((16, 16), aapi.float64)
38-
b = aapi.zeros((16, 16), aapi.float64)
38+
b = aapi.fromfunction(lambda i, j: 10 * i + j, (16, 16), dtype=aapi.float64)
3939
a[7:16:3, 4:10:2] = b[4:7, 10:16:2]
4040
return a
4141

4242
assert runAndCompare(doit)
4343

44+
def test_setitem4(self):
45+
# Note: test halo update without send buffer
46+
def doit(aapi):
47+
a = aapi.ones((16, 16), aapi.float64)
48+
b = aapi.fromfunction(lambda i, j: 10 * i + j, (16, 16), dtype=aapi.float64)
49+
a[7:16:3, 0:16] = b[4:7, 0:16]
50+
return a
51+
52+
assert runAndCompare(doit)
53+
54+
def test_setitem5(self):
55+
# Note: test assignment to one full local part
56+
def doit(aapi):
57+
a = aapi.fromfunction(lambda i, j: 10 * i + j, (16, 16), dtype=aapi.int64)
58+
a[0:10, 4:11] = a[0:10, 4:11]
59+
return a
60+
61+
assert runAndCompare(doit)
62+
4463
def test_colon(self):
4564
a = dt.ones((16, 16), dt.float64)
4665
b = dt.zeros((16, 16), dt.float64)

test/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import numpy
22
import ddptensor
3+
from ddptensor.numpy import fromfunction
4+
5+
ddptensor.fromfunction = fromfunction
36

47

58
def runAndCompare(func, do_gather=True):

0 commit comments

Comments
 (0)