
Commit f0dd8b6

Add to_device method (#74)
* add array.to_device method
* tests fail because the remote-temporaries pass removes memref.copy, thus effectively producing a view
1 parent: fc01aa1

File tree

    sharpy/ndarray.py
    src/ManipOp.cpp
    src/_sharpy.cpp
    src/include/sharpy/CppTypes.hpp
    src/include/sharpy/ManipOp.hpp
    src/jit/mlir.cpp
    test/test_manip.py

7 files changed (+97 / -14 lines)

sharpy/ndarray.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@ def __repr__(self):
     def astype(self, dtype, copy=False):
         return ndarray(self._t.astype(dtype, copy))
 
+    def to_device(self, device=""):
+        return ndarray(self._t.to_device(device))
+
     def __getitem__(self, key):
         key = key if isinstance(key, tuple) else (key,)
         key = [slicefy(x) for x in key]
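
For orientation, a minimal usage sketch of the new Python-level method (assuming the sharpy runtime is already initialized, as the project's test fixtures do; all calls shown also appear in the test/test_manip.py diff below):

    import sharpy as sp

    a = sp.arange(0, 8, 1, sp.int32)  # host array holding 0..7
    b = a.to_device()                 # default device="" keeps the data on the host
    c = a.to_device(device="GPU")     # request GPU placement (skipped in the tests: device support pending)
    assert (sp.to_numpy(b) == sp.to_numpy(a)).all()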

src/ManipOp.cpp

Lines changed: 67 additions & 2 deletions
@@ -142,6 +142,65 @@ struct DeferredAsType : public Deferred {
   }
 };
 
+// ***************************************************************************
+
+struct DeferredToDevice : public Deferred {
+  id_type _a;
+
+  DeferredToDevice() = default;
+  DeferredToDevice(const array_i::future_type &a, const std::string &device)
+      : Deferred(a.dtype(), a.shape(), device, a.team()), _a(a.guid()) {}
+
+  bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
+                     jit::DepManager &dm) override {
+    auto av = dm.getDependent(builder, _a);
+
+    auto srcType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
+    assert(srcType);
+    // copy envs, drop gpu env (if any)
+    auto srcEnvs = srcType.getEnvironments();
+    ::mlir::SmallVector<::mlir::Attribute> envs;
+    for (auto e : srcEnvs) {
+      if (!::mlir::isa<::imex::region::GPUEnvAttr>(e)) {
+        envs.emplace_back(e);
+      }
+    }
+    // append device attr
+    if (!_device.empty()) {
+      envs.emplace_back(
+          ::imex::region::GPUEnvAttr::get(builder.getStringAttr(_device)));
+    }
+    auto outType = ::imex::ndarray::NDArrayType::get(srcType.getShape(),
+                                                     srcType.getElementType(),
+                                                     envs, srcType.getLayout());
+    auto res = builder.create<::imex::ndarray::CopyOp>(loc, outType, av);
+    dm.addVal(this->guid(), res,
+              [this](uint64_t rank, void *l_allocated, void *l_aligned,
+                     intptr_t l_offset, const intptr_t *l_sizes,
+                     const intptr_t *l_strides, void *o_allocated,
+                     void *o_aligned, intptr_t o_offset,
+                     const intptr_t *o_sizes, const intptr_t *o_strides,
+                     void *r_allocated, void *r_aligned, intptr_t r_offset,
+                     const intptr_t *r_sizes, const intptr_t *r_strides,
+                     uint64_t *lo_allocated, uint64_t *lo_aligned) {
+                auto t = mk_tnsr(reinterpret_cast<Transceiver *>(this->team()),
+                                 _dtype, this->shape(), l_allocated, l_aligned,
+                                 l_offset, l_sizes, l_strides, o_allocated,
+                                 o_aligned, o_offset, o_sizes, o_strides,
+                                 r_allocated, r_aligned, r_offset, r_sizes,
+                                 r_strides, lo_allocated, lo_aligned);
+                this->set_value(std::move(t));
+              });
+    return false;
+  }
+
+  FactoryId factory() const { return F_TODEVICE; }
+
+  template <typename S> void serialize(S &ser) {
+    ser.template value<sizeof(_a)>(_a);
+  }
+};
+
 FutureArray *ManipOp::reshape(const FutureArray &a, const shape_type &shape,
                               const py::object &copy) {
   auto doCopy = copy.is_none()
@@ -155,12 +214,18 @@ FutureArray *ManipOp::reshape(const FutureArray &a, const shape_type &shape,
   return new FutureArray(defer<DeferredReshape>(a.get(), shape, doCopy));
 }
 
-FutureArray *AsType::astype(const FutureArray &a, DTypeId dtype,
-                            const py::object &copy) {
+FutureArray *ManipOp::astype(const FutureArray &a, DTypeId dtype,
+                             const py::object &copy) {
   auto doCopy = copy.is_none() ? false : copy.cast<bool>();
   return new FutureArray(defer<DeferredAsType>(a.get(), dtype, doCopy));
 }
 
+FutureArray *ManipOp::to_device(const FutureArray &a,
+                                const std::string &device) {
+  return new FutureArray(defer<DeferredToDevice>(a.get(), device));
+}
+
 FACTORY_INIT(DeferredReshape, F_RESHAPE);
 FACTORY_INIT(DeferredAsType, F_ASTYPE);
+FACTORY_INIT(DeferredToDevice, F_TODEVICE);
 } // namespace SHARPY
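
The heart of DeferredToDevice::generate_mlir above is a rewrite of the NDArray type's environment list: any existing GPU environment attribute is dropped and, if a non-empty device string was passed, a fresh GPUEnvAttr for that device is appended before an ndarray.copy into the resulting type is emitted. A rough Python analogue of just that filtering step, with made-up string tags standing in for the MLIR attributes:

    # Illustrative sketch only -- plain strings stand in for MLIR environment attributes.
    def rewrite_envs(src_envs, device):
        envs = [e for e in src_envs if not e.startswith("gpu:")]  # drop any existing GPU env
        if device:                                                # non-empty device string
            envs.append("gpu:" + device)                          # append the new GPU env
        return envs

    assert rewrite_envs(["dist", "gpu:old"], "GPU") == ["dist", "gpu:GPU"]
    assert rewrite_envs(["dist", "gpu:old"], "") == ["dist"]  # to_device("") lands back on the host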

src/_sharpy.cpp

Lines changed: 2 additions & 1 deletion
@@ -203,7 +203,8 @@ PYBIND11_MODULE(_sharpy, m) {
       .def("__index__",
            [](const FutureArray &f) { REPL_SYNC_RETURN(f, __int__); })
       // attributes returning a new FutureArray
-      .def("astype", &AsType::astype)
+      .def("astype", &ManipOp::astype)
+      .def("to_device", &ManipOp::to_device)
       .def("__getitem__", &GetItem::__getitem__)
       .def("__setitem__", &SetItem::__setitem__)
       .def("map", &SetItem::map);

src/include/sharpy/CppTypes.hpp

Lines changed: 1 addition & 0 deletions
@@ -343,6 +343,7 @@ enum FactoryId : int {
   F_SERVICE,
   F_SETITEM,
   F_ASTYPE,
+  F_TODEVICE,
   FACTORY_LAST
 };

src/include/sharpy/ManipOp.hpp

Lines changed: 3 additions & 2 deletions
@@ -14,10 +14,11 @@ namespace SHARPY {
 struct ManipOp {
   static FutureArray *reshape(const FutureArray &a, const shape_type &shape,
                               const py::object &copy);
-};
 
-struct AsType {
   static FutureArray *astype(const FutureArray &a, DTypeId dtype,
                              const py::object &copy);
+
+  static FutureArray *to_device(const FutureArray &a,
+                                const std::string &device);
 };
 } // namespace SHARPY

src/jit/mlir.cpp

Lines changed: 9 additions & 9 deletions
@@ -354,15 +354,14 @@ void DepManager::deliver(std::vector<intptr_t> &outputV, uint64_t sz) {
       // no sizes/stride needed, just skip
       pos += memref_sz(1);
       // call finalization callback
-      v->second(
-          rank, t_allocated[0], t_aligned[0], t_offset[0], t_sizes[0],
-          t_strides[0], // lhsHalo
-          t_allocated[1], t_aligned[1], t_offset[1], t_sizes[1],
-          t_strides[1], // lData
-          t_allocated[2], t_aligned[2], t_offset[2], t_sizes[2],
-          t_strides[2], // rhsHalo
-          lo_allocated,
-          lo_aligned + lo_offset // local offset is 1d array of uint64_t
+      v->second(rank, t_allocated[0], t_aligned[0], t_offset[0], t_sizes[0],
+                t_strides[0], // lhsHalo
+                t_allocated[1], t_aligned[1], t_offset[1], t_sizes[1],
+                t_strides[1], // lData
+                t_allocated[2], t_aligned[2], t_offset[2], t_sizes[2],
+                t_strides[2], // rhsHalo
+                lo_allocated,
+                lo_aligned + lo_offset // local offset is 1d array of uint64_t
       );
     } else { // 0d array or non-dist
       pos += getMR(rank, &output[pos], t_allocated[1], t_aligned[1],
@@ -631,6 +630,7 @@ JIT::JIT()
   _context.getOrLoadDialect<::mlir::arith::ArithDialect>();
   _context.getOrLoadDialect<::mlir::func::FuncDialect>();
   _context.getOrLoadDialect<::mlir::linalg::LinalgDialect>();
+  _context.getOrLoadDialect<::imex::region::RegionDialect>();
   // create the pass pipeline from string
   if (::mlir::failed(::mlir::parsePassPipeline(pass_pipeline, _pm)))
     throw std::runtime_error("failed to parse pass pipeline");

test/test_manip.py

Lines changed: 12 additions & 0 deletions
@@ -68,3 +68,15 @@ def test_astype_suite(self):
         b = a.astype(to_type)
         assert b.dtype == to_type
         assert numpy.allclose(sp.to_numpy(b), [0, 1, 2, 3, 4, 5, 6, 7])
+
+    @pytest.mark.skip(reason="FIXME copy is eliminated")
+    def test_todevice_host2host(self):
+        a = sp.arange(0, 8, 1, sp.int32)
+        b = a.to_device()
+        assert numpy.allclose(sp.to_numpy(b), [0, 1, 2, 3, 4, 5, 6, 7])
+
+    @pytest.mark.skip(reason="FIXME copy is eliminated, device support")
+    def test_todevice_host2gpu(self):
+        a = sp.arange(0, 8, 1, sp.int32)
+        b = a.to_device(device="GPU")
+        assert numpy.allclose(sp.to_numpy(b), [0, 1, 2, 3, 4, 5, 6, 7])
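
Both tests are skipped for the reason flagged in the commit message: a later pass removes the memref.copy, so to_device currently yields what is effectively a view of the source. A hypothetical follow-up check, not part of this commit, that would distinguish a genuine copy from a view once the copy survives (it assumes slice assignment works as bound via __setitem__ in _sharpy.cpp):

    def test_todevice_is_a_copy():        # hypothetical, not in this commit
        a = sp.arange(0, 8, 1, sp.int32)
        b = a.to_device()
        a[0:1] = 42                       # mutate the source after the transfer
        assert sp.to_numpy(b)[0] == 0     # a real copy must not observe the update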
