@@ -142,6 +142,65 @@ struct DeferredAsType : public Deferred {
142142 }
143143};
144144
145+ // ***************************************************************************
146+
147+ struct DeferredToDevice : public Deferred {
148+ id_type _a;
149+
150+ DeferredToDevice () = default ;
151+ DeferredToDevice (const array_i::future_type &a, const std::string &device)
152+ : Deferred(a.dtype(), a.shape(), device, a.team()), _a(a.guid()) {}
153+
154+ bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location &loc,
155+ jit::DepManager &dm) override {
156+ auto av = dm.getDependent (builder, _a);
157+
158+ auto srcType = av.getType ().dyn_cast <::imex::ndarray::NDArrayType>();
159+ assert (srcType);
160+ // copy envs, drop gpu env (if any)
161+ auto srcEnvs = srcType.getEnvironments ();
162+ ::mlir::SmallVector<::mlir::Attribute> envs;
163+ for (auto e : srcEnvs) {
164+ if (!::mlir::isa<::imex::region::GPUEnvAttr>(e)) {
165+ envs.emplace_back (e);
166+ }
167+ }
168+ // append device attr
169+ if (!_device.empty ()) {
170+ envs.emplace_back (
171+ ::imex::region::GPUEnvAttr::get (builder.getStringAttr(_device)));
172+ }
173+ auto outType = ::imex::ndarray::NDArrayType::get (srcType.getShape (),
174+ srcType.getElementType (),
175+ envs, srcType.getLayout ());
176+ auto res = builder.create <::imex::ndarray::CopyOp>(loc, outType, av);
177+ dm.addVal (this ->guid (), res,
178+ [this ](uint64_t rank, void *l_allocated, void *l_aligned,
179+ intptr_t l_offset, const intptr_t *l_sizes,
180+ const intptr_t *l_strides, void *o_allocated,
181+ void *o_aligned, intptr_t o_offset,
182+ const intptr_t *o_sizes, const intptr_t *o_strides,
183+ void *r_allocated, void *r_aligned, intptr_t r_offset,
184+ const intptr_t *r_sizes, const intptr_t *r_strides,
185+ uint64_t *lo_allocated, uint64_t *lo_aligned) {
186+ auto t = mk_tnsr (reinterpret_cast <Transceiver *>(this ->team ()),
187+ _dtype, this ->shape (), l_allocated, l_aligned,
188+ l_offset, l_sizes, l_strides, o_allocated,
189+ o_aligned, o_offset, o_sizes, o_strides,
190+ r_allocated, r_aligned, r_offset, r_sizes,
191+ r_strides, lo_allocated, lo_aligned);
192+ this ->set_value (std::move (t));
193+ });
194+ return false ;
195+ }
196+
197+ FactoryId factory () const { return F_TODEVICE; }
198+
199+ template <typename S> void serialize (S &ser) {
200+ ser.template value <sizeof (_a)>(_a);
201+ }
202+ };
203+
145204FutureArray *ManipOp::reshape (const FutureArray &a, const shape_type &shape,
146205 const py::object ©) {
147206 auto doCopy = copy.is_none ()
@@ -155,12 +214,18 @@ FutureArray *ManipOp::reshape(const FutureArray &a, const shape_type &shape,
155214 return new FutureArray (defer<DeferredReshape>(a.get (), shape, doCopy));
156215}
157216
158- FutureArray *AsType ::astype (const FutureArray &a, DTypeId dtype,
159- const py::object ©) {
217+ FutureArray *ManipOp ::astype (const FutureArray &a, DTypeId dtype,
218+ const py::object ©) {
160219 auto doCopy = copy.is_none () ? false : copy.cast <bool >();
161220 return new FutureArray (defer<DeferredAsType>(a.get (), dtype, doCopy));
162221}
163222
223+ FutureArray *ManipOp::to_device (const FutureArray &a,
224+ const std::string &device) {
225+ return new FutureArray (defer<DeferredToDevice>(a.get (), device));
226+ }
227+
164228FACTORY_INIT (DeferredReshape, F_RESHAPE);
165229FACTORY_INIT (DeferredAsType, F_ASTYPE);
230+ FACTORY_INIT (DeferredToDevice, F_TODEVICE);
166231} // namespace SHARPY
0 commit comments