88#include " ddptensor/Factory.hpp"
99#include " ddptensor/Transceiver.hpp"
1010#include " ddptensor/TypeDispatch.hpp"
11+ #include " ddptensor/jit/mlir.hpp"
1112
1213#include < imex/Dialect/PTensor/IR/PTensorOps.h>
1314#include < imex/Utils/PassUtils.h>
1819#include < mlir/Dialect/Tensor/IR/Tensor.h>
1920#include < mlir/IR/Builders.h>
2021
22+ namespace DDPT {
23+
2124static const char *FORCE_DIST = getenv(" DDPT_FORCE_DIST" );
2225
2326inline uint64_t mkTeam (uint64_t team) {
@@ -36,8 +39,9 @@ struct DeferredFull : public Deferred {
3639 : Deferred(dtype, shape, team, true ), _val(val) {}
3740
3841 template <typename T> struct ValAndDType {
39- static ::mlir::Value op (::mlir::OpBuilder &builder, ::mlir::Location loc,
40- const PyScalar &val, ::imex::ptensor::DType &dtyp) {
42+ static ::mlir::Value op (::mlir::OpBuilder &builder,
43+ const ::mlir::Location &loc, const PyScalar &val,
44+ ::imex::ptensor::DType &dtyp) {
4145 dtyp = jit::PT_DTYPE<T>::value;
4246
4347 if (is_none (val)) {
@@ -54,7 +58,7 @@ struct DeferredFull : public Deferred {
5458 };
5559 };
5660
57- bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
61+ bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location & loc,
5862 jit::DepManager &dm) override {
5963 ::mlir::SmallVector<::mlir::Value> shp (rank ());
6064 for (auto i = 0 ; i < rank (); ++i) {
@@ -124,7 +128,7 @@ struct DeferredArange : public Deferred {
124128 team, true ),
125129 _start (start), _end(end), _step(step) {}
126130
127- bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
131+ bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location & loc,
128132 jit::DepManager &dm) override {
129133 // ::mlir::Value
130134 auto transceiver = getTransceiver ();
@@ -192,7 +196,7 @@ struct DeferredLinspace : public Deferred {
192196 : Deferred(dtype, {static_cast <shape_type::value_type>(num)}, team, true ),
193197 _start (start), _end(end), _num(num), _endpoint(endpoint) {}
194198
195- bool generate_mlir (::mlir::OpBuilder &builder, ::mlir::Location loc,
199+ bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location & loc,
196200 jit::DepManager &dm) override {
197201 // ::mlir::Value
198202 auto teamV = team () == 0
@@ -247,14 +251,15 @@ ddptensor *Creator::linspace(double start, double end, uint64_t num,
247251
248252// ***************************************************************************
249253
254+ extern DTypeId DEFAULT_FLOAT;
255+ extern DTypeId DEFAULT_INT;
256+
250257std::pair<ddptensor *, bool > Creator::mk_future (const py::object &b,
251- uint64_t team) {
258+ uint64_t team, DTypeId dtype ) {
252259 if (py::isinstance<ddptensor>(b)) {
253260 return {b.cast <ddptensor *>(), false };
254- } else if (py::isinstance<py::float_>(b)) {
255- return {Creator::full ({}, b, FLOAT64, team), true };
256- } else if (py::isinstance<py::int_>(b)) {
257- return {Creator::full ({}, b, INT64, team), true };
261+ } else if (py::isinstance<py::float_>(b) || py::isinstance<py::int_>(b)) {
262+ return {Creator::full ({}, b, dtype, team), true };
258263 }
259264 throw std::runtime_error (
260265 " Invalid right operand to elementwise binary operation" );
@@ -263,3 +268,4 @@ std::pair<ddptensor *, bool> Creator::mk_future(const py::object &b,
263268FACTORY_INIT (DeferredFull, F_FULL);
264269FACTORY_INIT (DeferredArange, F_ARANGE);
265270FACTORY_INIT (DeferredLinspace, F_LINSPACE);
271+ } // namespace DDPT
0 commit comments