Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit fc01aa1

Browse files
authored
Add astype method (#73)
1 parent 57872be commit fc01aa1

File tree

9 files changed

+160
-34
lines changed

9 files changed

+160
-34
lines changed

imex_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
643547384513a4ae510dcd342ae1637dc802166c
1+
ea310778459a22ff8889a3c5cf3aba39a6e9e772

sharpy/ndarray.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
The array class for sharpy, a distributed implementation of the
2+
The array class for sharpy, a distributed implementation of the
33
array API as defined here: https://data-apis.org/array-api/latest
44
"""
55
#
@@ -51,6 +51,9 @@ def __repr__(self):
5151
for att in api.attributes:
5252
exec(f"{att} = property(lambda self: self._t.{att})")
5353

54+
def astype(self, dtype, copy=False):
55+
return ndarray(self._t.astype(dtype, copy))
56+
5457
def __getitem__(self, key):
5558
key = key if isinstance(key, tuple) else (key,)
5659
key = [slicefy(x) for x in key]

src/ManipOp.cpp

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
*/
66

77
#include "sharpy/ManipOp.hpp"
8-
#include "sharpy/NDArray.hpp"
98
#include "sharpy/Deferred.hpp"
109
#include "sharpy/Factory.hpp"
10+
#include "sharpy/NDArray.hpp"
1111
#include "sharpy/TypeDispatch.hpp"
1212
#include "sharpy/jit/mlir.hpp"
1313

@@ -82,8 +82,68 @@ struct DeferredReshape : public Deferred {
8282
}
8383
};
8484

85+
// ***************************************************************************
86+
87+
struct DeferredAsType : public Deferred {
88+
id_type _a;
89+
bool _copy;
90+
91+
DeferredAsType() = default;
92+
DeferredAsType(const array_i::future_type &a, DTypeId dtype, bool copy)
93+
: Deferred(dtype, a.shape(), a.device(), a.team()), _a(a.guid()),
94+
_copy(copy) {}
95+
96+
template <typename T> struct convDType {
97+
static ::imex::ndarray::DType op() { return jit::PT_DTYPE<T>::value; };
98+
};
99+
100+
bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
101+
jit::DepManager &dm) override {
102+
const auto dtype = this->dtype();
103+
auto av = dm.getDependent(builder, _a);
104+
105+
auto copyAttr = ::imex::getIntAttr(builder, _copy, 1);
106+
// construct NDArrayType with same shape and given dtype
107+
::imex::ndarray::DType ndDType = dispatch<convDType>(dtype);
108+
auto mlirElType = ::imex::ndarray::toMLIR(builder, ndDType);
109+
auto arType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
110+
assert(arType);
111+
auto outType = arType.cloneWith(std::nullopt, mlirElType);
112+
auto res = builder.create<::imex::ndarray::CastElemTypeOp>(loc, outType, av,
113+
copyAttr);
114+
dm.addVal(this->guid(), res,
115+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
116+
intptr_t l_offset, const intptr_t *l_sizes,
117+
const intptr_t *l_strides, void *o_allocated,
118+
void *o_aligned, intptr_t o_offset,
119+
const intptr_t *o_sizes, const intptr_t *o_strides,
120+
void *r_allocated, void *r_aligned, intptr_t r_offset,
121+
const intptr_t *r_sizes, const intptr_t *r_strides,
122+
uint64_t *lo_allocated, uint64_t *lo_aligned) {
123+
auto t = mk_tnsr(reinterpret_cast<Transceiver *>(this->team()),
124+
_dtype, this->shape(), l_allocated, l_aligned,
125+
l_offset, l_sizes, l_strides, o_allocated,
126+
o_aligned, o_offset, o_sizes, o_strides,
127+
r_allocated, r_aligned, r_offset, r_sizes,
128+
r_strides, lo_allocated, lo_aligned);
129+
if (Registry::has(_a)) {
130+
t->set_base(Registry::get(_a).get());
131+
} // else _a is a temporary and was dropped
132+
this->set_value(std::move(t));
133+
});
134+
return false;
135+
}
136+
137+
FactoryId factory() const { return F_ASTYPE; }
138+
139+
template <typename S> void serialize(S &ser) {
140+
ser.template value<sizeof(_a)>(_a);
141+
ser.template value<sizeof(_copy)>(_copy);
142+
}
143+
};
144+
85145
FutureArray *ManipOp::reshape(const FutureArray &a, const shape_type &shape,
86-
const py::object &copy) {
146+
const py::object &copy) {
87147
auto doCopy = copy.is_none()
88148
? DeferredReshape::COPY_POSSIBLE
89149
: (copy.cast<bool>() ? DeferredReshape::COPY_ALWAYS
@@ -95,5 +155,12 @@ FutureArray *ManipOp::reshape(const FutureArray &a, const shape_type &shape,
95155
return new FutureArray(defer<DeferredReshape>(a.get(), shape, doCopy));
96156
}
97157

158+
FutureArray *AsType::astype(const FutureArray &a, DTypeId dtype,
159+
const py::object &copy) {
160+
auto doCopy = copy.is_none() ? false : copy.cast<bool>();
161+
return new FutureArray(defer<DeferredAsType>(a.get(), dtype, doCopy));
162+
}
163+
98164
FACTORY_INIT(DeferredReshape, F_RESHAPE);
165+
FACTORY_INIT(DeferredAsType, F_ASTYPE);
99166
} // namespace SHARPY

src/SetGetItem.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
#include "sharpy/SetGetItem.hpp"
99
#include "sharpy/CollComm.hpp"
1010
#include "sharpy/Creator.hpp"
11-
#include "sharpy/NDArray.hpp"
1211
#include "sharpy/Deferred.hpp"
1312
#include "sharpy/Factory.hpp"
1413
#include "sharpy/Mediator.hpp"
14+
#include "sharpy/NDArray.hpp"
1515
#include "sharpy/NDSlice.hpp"
1616
#include "sharpy/Transceiver.hpp"
1717
#include "sharpy/TypeDispatch.hpp"
@@ -152,8 +152,7 @@ struct DeferredSetItem : public Deferred {
152152
NDSlice _slc;
153153

154154
DeferredSetItem() = default;
155-
DeferredSetItem(const array_i::future_type &a,
156-
const array_i::future_type &b,
155+
DeferredSetItem(const array_i::future_type &a, const array_i::future_type &b,
157156
const std::vector<py::slice> &v)
158157
: Deferred(a.dtype(), a.shape(), a.device(), a.team(), a.guid()),
159158
_a(a.guid()), _b(b.guid()), _slc(v, a.shape()) {}
@@ -308,22 +307,24 @@ struct DeferredGetItem : public Deferred {
308307
// ***************************************************************************
309308

310309
FutureArray *GetItem::__getitem__(const FutureArray &a,
311-
const std::vector<py::slice> &v) {
310+
const std::vector<py::slice> &v) {
312311
auto afut = a.get();
313312
NDSlice slc(v, afut.shape());
314313
return new FutureArray(defer<DeferredGetItem>(afut, std::move(slc)));
315314
}
316315

317-
GetItem::py_future_type GetItem::get_locals(const FutureArray &a, py::handle h) {
316+
GetItem::py_future_type GetItem::get_locals(const FutureArray &a,
317+
py::handle h) {
318318
return defer<DeferredGetLocals>(a.get(), h);
319319
}
320320

321321
GetItem::py_future_type GetItem::gather(const FutureArray &a, rank_type root) {
322322
return defer<DeferredGather>(a.get(), root);
323323
}
324324

325-
FutureArray *SetItem::__setitem__(FutureArray &a, const std::vector<py::slice> &v,
326-
const py::object &b) {
325+
FutureArray *SetItem::__setitem__(FutureArray &a,
326+
const std::vector<py::slice> &v,
327+
const py::object &b) {
327328
auto bb =
328329
Creator::mk_future(b, a.get().device(), a.get().team(), a.get().dtype());
329330
a.put(defer<DeferredSetItem>(a.get(), bb.first->get(), v));

src/_sharpy.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ void sync_promises() {
112112

113113
/// trigger compile&run and return future value
114114
#define PY_SYNC_RETURN(_f) \
115-
int vtWaitSym, vtSHARPYClass; \
116-
VT(VT_classdef, "sharpy", &vtSHARPYClass); \
117-
VT(VT_funcdef, "wait", vtSHARPYClass, &vtWaitSym); \
115+
int vtWaitSym, vtSHARPYClass; \
116+
VT(VT_classdef, "sharpy", &vtSHARPYClass); \
117+
VT(VT_funcdef, "wait", vtSHARPYClass, &vtWaitSym); \
118118
VT(VT_begin, vtWaitSym); \
119119
py::gil_scoped_release release; \
120120
Service::run(); \
@@ -124,22 +124,21 @@ void sync_promises() {
124124

125125
/// trigger compile&run and return given attribute _x
126126
#define SYNC_RETURN(_f, _a) \
127-
int vtWaitSym, vtSHARPYClass; \
128-
VT(VT_classdef, "sharpy", &vtSHARPYClass); \
129-
VT(VT_funcdef, "wait", vtSHARPYClass, &vtWaitSym); \
127+
int vtWaitSym, vtSHARPYClass; \
128+
VT(VT_classdef, "sharpy", &vtSHARPYClass); \
129+
VT(VT_funcdef, "wait", vtSHARPYClass, &vtWaitSym); \
130130
VT(VT_begin, vtWaitSym); \
131131
py::gil_scoped_release release; \
132132
Service::run(); \
133-
auto r = (_f).get().get()->_a(); \
133+
auto r = (_f).get().get() -> _a(); \
134134
VT(VT_end, vtWaitSym); \
135135
return r
136136

137137
/// Replicate sharpy/future and SYNC_RETURN attribute _a
138138
#define REPL_SYNC_RETURN(_f, _a) \
139-
auto r_ = std::unique_ptr<FutureArray>(Service::replicate(f)); \
139+
auto r_ = std::unique_ptr<FutureArray>(Service::replicate(f)); \
140140
SYNC_RETURN(r_->get(), _a)
141141

142-
143142
// Finally our Python module
144143
PYBIND11_MODULE(_sharpy, m) {
145144

@@ -183,13 +182,13 @@ PYBIND11_MODULE(_sharpy, m) {
183182

184183
py::class_<FutureArray>(m, "SHARPYFuture")
185184
// attributes we can get from the future itself
186-
.def_property_readonly("dtype",
187-
[](const FutureArray &f) { return f.get().dtype(); })
188-
.def_property_readonly("ndim",
189-
[](const FutureArray &f) { return f.get().rank(); })
185+
.def_property_readonly(
186+
"dtype", [](const FutureArray &f) { return f.get().dtype(); })
187+
.def_property_readonly(
188+
"ndim", [](const FutureArray &f) { return f.get().rank(); })
190189
// attributes we can get from future without additional computation
191-
.def_property_readonly("shape",
192-
[](const FutureArray &f) { SYNC_RETURN(f, shape); })
190+
.def_property_readonly(
191+
"shape", [](const FutureArray &f) { SYNC_RETURN(f, shape); })
193192
.def_property_readonly("size",
194193
[](const FutureArray &f) { SYNC_RETURN(f, size); })
195194
.def("__len__", [](const FutureArray &f) { SYNC_RETURN(f, __len__); })
@@ -199,10 +198,12 @@ PYBIND11_MODULE(_sharpy, m) {
199198
[](const FutureArray &f) { REPL_SYNC_RETURN(f, __bool__); })
200199
.def("__float__",
201200
[](const FutureArray &f) { REPL_SYNC_RETURN(f, __float__); })
202-
.def("__int__", [](const FutureArray &f) { REPL_SYNC_RETURN(f, __int__); })
201+
.def("__int__",
202+
[](const FutureArray &f) { REPL_SYNC_RETURN(f, __int__); })
203203
.def("__index__",
204204
[](const FutureArray &f) { REPL_SYNC_RETURN(f, __int__); })
205205
// attributes returning a new FutureArray
206+
.def("astype", &AsType::astype)
206207
.def("__getitem__", &GetItem::__getitem__)
207208
.def("__setitem__", &SetItem::__setitem__)
208209
.def("map", &SetItem::map);
@@ -213,7 +214,6 @@ PYBIND11_MODULE(_sharpy, m) {
213214
.def("seed", &Random::seed)
214215
.def("uniform", &Random::rand);
215216

216-
217217
// py::class_<dpdlpack>(m, "dpdlpack")
218218
// .def("__dlpack__", &dpdlpack.__dlpack__);
219219
}

src/include/sharpy/CppTypes.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ enum FactoryId : int {
342342
F_RESHAPE,
343343
F_SERVICE,
344344
F_SETITEM,
345+
F_ASTYPE,
345346
FACTORY_LAST
346347
};
347348

src/include/sharpy/ManipOp.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@
66

77
#pragma once
88

9-
#include "UtilsAndTypes.hpp"
109
#include "FutureArray.hpp"
10+
#include "UtilsAndTypes.hpp"
1111
#include "p2c_ids.hpp"
1212

1313
namespace SHARPY {
1414
struct ManipOp {
1515
static FutureArray *reshape(const FutureArray &a, const shape_type &shape,
16-
const py::object &copy);
16+
const py::object &copy);
17+
};
18+
19+
struct AsType {
20+
static FutureArray *astype(const FutureArray &a, DTypeId dtype,
21+
const py::object &copy);
1722
};
1823
} // namespace SHARPY

src/include/sharpy/SetGetItem.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
#pragma once
99

10-
#include "UtilsAndTypes.hpp"
1110
#include "FutureArray.hpp"
11+
#include "UtilsAndTypes.hpp"
1212
#include "p2c_ids.hpp"
1313

1414
namespace SHARPY {
@@ -18,16 +18,17 @@ struct GetItem {
1818
using py_future_type = std::shared_future<py::object>;
1919

2020
static FutureArray *__getitem__(const FutureArray &a,
21-
const std::vector<py::slice> &v);
21+
const std::vector<py::slice> &v);
2222
static py::object get_slice(const FutureArray &a,
2323
const std::vector<py::slice> &v);
2424
static py_future_type get_locals(const FutureArray &a, py::handle h);
2525
static py_future_type gather(const FutureArray &a, rank_type root);
2626
};
2727

2828
struct SetItem {
29-
static FutureArray *__setitem__(FutureArray &a, const std::vector<py::slice> &v,
30-
const py::object &b);
29+
static FutureArray *__setitem__(FutureArray &a,
30+
const std::vector<py::slice> &v,
31+
const py::object &b);
3132
static FutureArray *map(FutureArray &a, py::object &b);
3233
};
3334
} // namespace SHARPY

test/test_manip.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
import sharpy as sp
2+
import numpy
13
from utils import runAndCompare
24
import pytest
5+
import itertools
6+
from mpi4py import MPI
7+
import os
38

49

510
class TestManip:
@@ -20,3 +25,46 @@ def doit(aapi):
2025
return aapi.reshape(c, [5, 6])
2126

2227
assert runAndCompare(doit)
28+
29+
def test_astype_f64i32(self):
30+
def doit(aapi):
31+
a = aapi.arange(0, 8, 1, aapi.float64)
32+
a += 0.3
33+
return a.astype(aapi.int32)
34+
35+
assert runAndCompare(doit)
36+
37+
def test_astype_view(self):
38+
a = sp.arange(0, 8, 1, sp.int32)
39+
b = a.astype(sp.int32)
40+
b[:3] = 5
41+
assert b.dtype == sp.int32
42+
assert numpy.allclose(sp.to_numpy(a), [5, 5, 5, 3, 4, 5, 6, 7])
43+
44+
@pytest.mark.skipif(
45+
MPI.COMM_WORLD.size > 1 or os.getenv("SHARPY_FORCE_DIST"),
46+
reason="FIXME multi-proc",
47+
)
48+
def test_astype_copy(self):
49+
a = sp.arange(0, 8, 1, sp.int32)
50+
b = a.astype(sp.int32, copy=True)
51+
b[:3] = 5
52+
assert b.dtype == sp.int32
53+
assert numpy.allclose(sp.to_numpy(a), [0, 1, 2, 3, 4, 5, 6, 7])
54+
assert numpy.allclose(sp.to_numpy(b), [5, 5, 5, 3, 4, 5, 6, 7])
55+
56+
def test_astype_suite(self):
57+
dtype_list = [
58+
sp.int32,
59+
sp.int64,
60+
sp.uint32,
61+
sp.uint64,
62+
sp.float32,
63+
sp.float64,
64+
]
65+
66+
for from_type, to_type in itertools.product(dtype_list, dtype_list):
67+
a = sp.arange(0, 8, 1, from_type)
68+
b = a.astype(to_type)
69+
assert b.dtype == to_type
70+
assert numpy.allclose(sp.to_numpy(b), [0, 1, 2, 3, 4, 5, 6, 7])

0 commit comments

Comments
 (0)