1010#include " ddptensor/TypeDispatch.hpp"
1111#include " ddptensor/jit/mlir.hpp"
1212
13+ #include < imex/Dialect/Dist/IR/DistOps.h>
1314#include < imex/Dialect/PTensor/IR/PTensorOps.h>
1415#include < imex/Utils/PassUtils.h>
1516
@@ -35,8 +36,8 @@ struct DeferredFull : public Deferred {
3536
3637 DeferredFull () = default ;
3738 DeferredFull (const shape_type &shape, PyScalar val, DTypeId dtype,
38- uint64_t team)
39- : Deferred(dtype, shape, team, true ), _val(val) {}
39+ const std::string &device, uint64_t team)
40+ : Deferred(dtype, shape, device, team ), _val(val) {}
4041
4142 template <typename T> struct ValAndDType {
4243 static ::mlir::Value op (::mlir::OpBuilder &builder,
@@ -67,35 +68,27 @@ struct DeferredFull : public Deferred {
6768
6869 ::imex::ptensor::DType dtyp;
6970 ::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);
70-
71- auto transceiver = getTransceiver ();
72- auto teamV = team () == 0
73- ? ::mlir::Value ()
74- : ::imex::createIndex (loc, builder,
75- reinterpret_cast <uint64_t >(team ()));
76-
77- auto rTyp = ::imex::ptensor::PTensorType::get (
78- shape (), imex::ptensor::toMLIR (builder, dtyp));
79-
80- dm.addVal (this ->guid (),
81- builder.create <::imex::ptensor::CreateOp>(loc, rTyp, shp, dtyp,
82- val, nullptr , teamV),
83- [this ](uint64_t rank, void *l_allocated, void *l_aligned,
84- intptr_t l_offset, const intptr_t *l_sizes,
85- const intptr_t *l_strides, void *o_allocated,
86- void *o_aligned, intptr_t o_offset,
87- const intptr_t *o_sizes, const intptr_t *o_strides,
88- void *r_allocated, void *r_aligned, intptr_t r_offset,
89- const intptr_t *r_sizes, const intptr_t *r_strides,
90- uint64_t *lo_allocated, uint64_t *lo_aligned) {
91- assert (rank == this ->rank ());
92- this ->set_value (std::move (mk_tnsr (
93- reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
94- this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
95- l_strides, o_allocated, o_aligned, o_offset, o_sizes,
96- o_strides, r_allocated, r_aligned, r_offset, r_sizes,
97- r_strides, lo_allocated, lo_aligned)));
98- });
71+ auto envs = jit::mkEnvs (builder, rank (), _device, team ());
72+
73+ dm.addVal (
74+ this ->guid (),
75+ builder.create <::imex::ptensor::CreateOp>(loc, shp, dtyp, val, envs),
76+ [this ](uint64_t rank, void *l_allocated, void *l_aligned,
77+ intptr_t l_offset, const intptr_t *l_sizes,
78+ const intptr_t *l_strides, void *o_allocated, void *o_aligned,
79+ intptr_t o_offset, const intptr_t *o_sizes,
80+ const intptr_t *o_strides, void *r_allocated, void *r_aligned,
81+ intptr_t r_offset, const intptr_t *r_sizes,
82+ const intptr_t *r_strides, uint64_t *lo_allocated,
83+ uint64_t *lo_aligned) {
84+ assert (rank == this ->rank ());
85+ this ->set_value (std::move (
86+ mk_tnsr (reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
87+ this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
88+ l_strides, o_allocated, o_aligned, o_offset, o_sizes,
89+ o_strides, r_allocated, r_aligned, r_offset, r_sizes,
90+ r_strides, lo_allocated, lo_aligned)));
91+ });
9992 return false ;
10093 }
10194
@@ -109,9 +102,11 @@ struct DeferredFull : public Deferred {
109102};
110103
111104ddptensor *Creator::full (const shape_type &shape, const py::object &val,
112- DTypeId dtype, uint64_t team) {
105+ DTypeId dtype, const std::string &device,
106+ uint64_t team) {
113107 auto v = mk_scalar (val, dtype);
114- return new ddptensor (defer<DeferredFull>(shape, v, dtype, mkTeam (team)));
108+ return new ddptensor (
109+ defer<DeferredFull>(shape, v, dtype, device, mkTeam (team)));
115110}
116111
117112// ***************************************************************************
@@ -121,33 +116,25 @@ struct DeferredArange : public Deferred {
121116
122117 DeferredArange () = default ;
123118 DeferredArange (uint64_t start, uint64_t end, uint64_t step, DTypeId dtype,
124- uint64_t team)
119+ const std::string &device, uint64_t team)
125120 : Deferred(dtype,
126121 {static_cast <shape_type::value_type>(
127122 (end - start + step + (step < 0 ? 1 : -1 )) / step)},
128- team, true ),
123+ device, team ),
129124 _start (start), _end(end), _step(step) {}
130125
131126 bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location &loc,
132127 jit::DepManager &dm) override {
133- // ::mlir::Value
134- auto transceiver = getTransceiver ();
135- auto teamV = team () == 0
136- ? ::mlir::Value ()
137- : ::imex::createIndex (loc, builder,
138- reinterpret_cast <uint64_t >(team ()));
139-
140128 auto _num = shape ()[0 ];
141-
142129 auto start = ::imex::createFloat (loc, builder, _start);
143130 auto stop = ::imex::createFloat (loc, builder, _start + _num * _step);
144131 auto num = ::imex::createIndex (loc, builder, _num);
145- auto rTyp = :: imex::ptensor::PTensorType::get (
146- shape (), imex::ptensor::toMLIR (builder, jit::getPTDType (_dtype) ));
132+ auto dtyp = jit::getPTDType ( dtype ());
133+ auto envs = jit::mkEnvs (builder, rank (), _device, team ( ));
147134
148135 dm.addVal (this ->guid (),
149- builder.create <::imex::ptensor::LinSpaceOp>(
150- loc, rTyp, start, stop, num, false , nullptr , teamV ),
136+ builder.create <::imex::ptensor::LinSpaceOp>(loc, start, stop, num,
137+ false , dtyp, envs ),
151138 [this ](uint64_t rank, void *l_allocated, void *l_aligned,
152139 intptr_t l_offset, const intptr_t *l_sizes,
153140 const intptr_t *l_strides, void *o_allocated,
@@ -157,7 +144,7 @@ struct DeferredArange : public Deferred {
157144 const intptr_t *r_sizes, const intptr_t *r_strides,
158145 uint64_t *lo_allocated, uint64_t *lo_aligned) {
159146 assert (rank == 1 );
160- assert (l_strides [0 ] == 1 );
147+ assert (o_strides [0 ] == 1 );
161148 this ->set_value (std::move (mk_tnsr (
162149 reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
163150 this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
@@ -178,9 +165,10 @@ struct DeferredArange : public Deferred {
178165};
179166
180167ddptensor *Creator::arange (uint64_t start, uint64_t end, uint64_t step,
181- DTypeId dtype, uint64_t team) {
168+ DTypeId dtype, const std::string &device,
169+ uint64_t team) {
182170 return new ddptensor (
183- defer<DeferredArange>(start, end, step, dtype, mkTeam (team)));
171+ defer<DeferredArange>(start, end, step, dtype, device, mkTeam (team)));
184172}
185173
186174// ***************************************************************************
@@ -192,27 +180,22 @@ struct DeferredLinspace : public Deferred {
192180
193181 DeferredLinspace () = default ;
194182 DeferredLinspace (double start, double end, uint64_t num, bool endpoint,
195- DTypeId dtype, uint64_t team)
196- : Deferred(dtype, {static_cast <shape_type::value_type>(num)}, team, true ),
183+ DTypeId dtype, const std::string &device, uint64_t team)
184+ : Deferred(dtype, {static_cast <shape_type::value_type>(num)}, device,
185+ team),
197186 _start (start), _end(end), _num(num), _endpoint(endpoint) {}
198187
199188 bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location &loc,
200189 jit::DepManager &dm) override {
201- // ::mlir::Value
202- auto teamV = team () == 0
203- ? ::mlir::Value ()
204- : ::imex::createIndex (loc, builder,
205- reinterpret_cast <uint64_t >(team ()));
206-
207190 auto start = ::imex::createFloat (loc, builder, _start);
208191 auto stop = ::imex::createFloat (loc, builder, _end);
209192 auto num = ::imex::createIndex (loc, builder, _num);
210- auto rTyp = :: imex::ptensor::PTensorType::get (
211- shape (), imex::ptensor::toMLIR (builder, jit::getPTDType (_dtype) ));
193+ auto dtyp = jit::getPTDType ( dtype ());
194+ auto envs = jit::mkEnvs (builder, rank (), _device, team ( ));
212195
213196 dm.addVal (this ->guid (),
214197 builder.create <::imex::ptensor::LinSpaceOp>(
215- loc, rTyp, start, stop, num, _endpoint, nullptr , teamV ),
198+ loc, start, stop, num, _endpoint, dtyp, envs ),
216199 [this ](uint64_t rank, void *l_allocated, void *l_aligned,
217200 intptr_t l_offset, const intptr_t *l_sizes,
218201 const intptr_t *l_strides, void *o_allocated,
@@ -244,9 +227,10 @@ struct DeferredLinspace : public Deferred {
244227};
245228
246229ddptensor *Creator::linspace (double start, double end, uint64_t num,
247- bool endpoint, DTypeId dtype, uint64_t team) {
248- return new ddptensor (
249- defer<DeferredLinspace>(start, end, num, endpoint, dtype, mkTeam (team)));
230+ bool endpoint, DTypeId dtype,
231+ const std::string &device, uint64_t team) {
232+ return new ddptensor (defer<DeferredLinspace>(start, end, num, endpoint, dtype,
233+ device, mkTeam (team)));
250234}
251235
252236// ***************************************************************************
@@ -255,11 +239,12 @@ extern DTypeId DEFAULT_FLOAT;
255239extern DTypeId DEFAULT_INT;
256240
257241std::pair<ddptensor *, bool > Creator::mk_future (const py::object &b,
242+ const std::string &device,
258243 uint64_t team, DTypeId dtype) {
259244 if (py::isinstance<ddptensor>(b)) {
260245 return {b.cast <ddptensor *>(), false };
261246 } else if (py::isinstance<py::float_>(b) || py::isinstance<py::int_>(b)) {
262- return {Creator::full ({}, b, dtype, team), true };
247+ return {Creator::full ({}, b, dtype, device, team), true };
263248 }
264249 throw std::runtime_error (
265250 " Invalid right operand to elementwise binary operation" );
0 commit comments