Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 57872be

Browse files
authored
updating to new imex (ptensor -> ndarray) (#72)
* updating to new imex (ptensor -> ndarray) * CI: running shallow water in parallel
1 parent 8f168b2 commit 57872be

File tree

17 files changed

+169
-166
lines changed

17 files changed

+169
-166
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ jobs:
155155
mpirun -n 4 python -u ./wave_equation.py
156156
python -u ./shallow_water.py -d f32
157157
python -u ./shallow_water.py -d f64
158+
mpirun -n 2 python -u ./shallow_water.py
159+
mpirun -n 3 python -u ./shallow_water.py
160+
mpirun -n 4 python -u ./shallow_water.py
158161
cd -
159162
- name: Cleanup
160163
run: |

imex_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
b36606fb8697e46adc6f834ffacef06613abd829
1+
643547384513a4ae510dcd342ae1637dc802166c

src/Creator.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include "sharpy/jit/mlir.hpp"
1212

1313
#include <imex/Dialect/Dist/IR/DistOps.h>
14-
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
14+
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
1515
#include <imex/Utils/PassUtils.h>
1616

1717
#include <mlir/Dialect/Arith/IR/Arith.h>
@@ -42,7 +42,7 @@ struct DeferredFull : public Deferred {
4242
template <typename T> struct ValAndDType {
4343
static ::mlir::Value op(::mlir::OpBuilder &builder,
4444
const ::mlir::Location &loc, const PyScalar &val,
45-
::imex::ptensor::DType &dtyp) {
45+
::imex::ndarray::DType &dtyp) {
4646
dtyp = jit::PT_DTYPE<T>::value;
4747

4848
if (is_none(val)) {
@@ -66,13 +66,13 @@ struct DeferredFull : public Deferred {
6666
shp[i] = ::imex::createIndex(loc, builder, shape()[i]);
6767
}
6868

69-
::imex::ptensor::DType dtyp;
69+
::imex::ndarray::DType dtyp;
7070
::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);
7171
auto envs = jit::mkEnvs(builder, rank(), _device, team());
7272

7373
dm.addVal(
7474
this->guid(),
75-
builder.create<::imex::ptensor::CreateOp>(loc, shp, dtyp, val, envs),
75+
builder.create<::imex::ndarray::CreateOp>(loc, shp, dtyp, val, envs),
7676
[this](uint64_t rank, void *l_allocated, void *l_aligned,
7777
intptr_t l_offset, const intptr_t *l_sizes,
7878
const intptr_t *l_strides, void *o_allocated, void *o_aligned,
@@ -133,7 +133,7 @@ struct DeferredArange : public Deferred {
133133
auto envs = jit::mkEnvs(builder, rank(), _device, team());
134134

135135
dm.addVal(this->guid(),
136-
builder.create<::imex::ptensor::LinSpaceOp>(loc, start, stop, num,
136+
builder.create<::imex::ndarray::LinSpaceOp>(loc, start, stop, num,
137137
false, dtyp, envs),
138138
[this](uint64_t rank, void *l_allocated, void *l_aligned,
139139
intptr_t l_offset, const intptr_t *l_sizes,
@@ -194,7 +194,7 @@ struct DeferredLinspace : public Deferred {
194194
auto envs = jit::mkEnvs(builder, rank(), _device, team());
195195

196196
dm.addVal(this->guid(),
197-
builder.create<::imex::ptensor::LinSpaceOp>(
197+
builder.create<::imex::ndarray::LinSpaceOp>(
198198
loc, start, stop, num, _endpoint, dtyp, envs),
199199
[this](uint64_t rank, void *l_allocated, void *l_aligned,
200200
intptr_t l_offset, const intptr_t *l_sizes,

src/Deferred.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "include/sharpy/itac.hpp"
1616
#include "include/sharpy/jit/mlir.hpp"
1717

18-
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
18+
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
1919
#include <mlir/Dialect/Func/IR/FuncOps.h>
2020
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
2121
#include <oneapi/tbb/concurrent_queue.h>

src/EWBinOp.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,68 +16,68 @@
1616
#include "sharpy/jit/mlir.hpp"
1717

1818
#include <imex/Dialect/Dist/IR/DistOps.h>
19-
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
19+
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
2020
#include <mlir/Dialect/Shape/IR/Shape.h>
2121
#include <mlir/IR/Builders.h>
2222

2323
namespace SHARPY {
2424

25-
// convert id of our binop to id of imex::ptensor binop
26-
static ::imex::ptensor::EWBinOpId sharpy2mlir(const EWBinOpId bop) {
25+
// convert id of our binop to id of imex::ndarray binop
26+
static ::imex::ndarray::EWBinOpId sharpy2mlir(const EWBinOpId bop) {
2727
switch (bop) {
2828
case __ADD__:
2929
case ADD:
3030
case __RADD__:
31-
return ::imex::ptensor::ADD;
31+
return ::imex::ndarray::ADD;
3232
case ATAN2:
33-
return ::imex::ptensor::ATAN2;
33+
return ::imex::ndarray::ATAN2;
3434
case __FLOORDIV__:
3535
case FLOOR_DIVIDE:
3636
case __RFLOORDIV__:
37-
return ::imex::ptensor::FLOOR_DIVIDE;
37+
return ::imex::ndarray::FLOOR_DIVIDE;
3838
// __MATMUL__ is handled before dispatching, see below
3939
case __MUL__:
4040
case MULTIPLY:
4141
case __RMUL__:
42-
return ::imex::ptensor::MULTIPLY;
42+
return ::imex::ndarray::MULTIPLY;
4343
case __SUB__:
4444
case SUBTRACT:
4545
case __RSUB__:
46-
return ::imex::ptensor::SUBTRACT;
46+
return ::imex::ndarray::SUBTRACT;
4747
case __TRUEDIV__:
4848
case DIVIDE:
4949
case __RTRUEDIV__:
50-
return ::imex::ptensor::TRUE_DIVIDE;
50+
return ::imex::ndarray::TRUE_DIVIDE;
5151
case __POW__:
5252
case POW:
5353
case __RPOW__:
54-
return ::imex::ptensor::POWER;
54+
return ::imex::ndarray::POWER;
5555
case LOGADDEXP:
56-
return ::imex::ptensor::LOGADDEXP;
56+
return ::imex::ndarray::LOGADDEXP;
5757
case __LSHIFT__:
5858
case BITWISE_LEFT_SHIFT:
5959
case __RLSHIFT__:
60-
return ::imex::ptensor::BITWISE_LEFT_SHIFT;
60+
return ::imex::ndarray::BITWISE_LEFT_SHIFT;
6161
case __MOD__:
6262
case REMAINDER:
6363
case __RMOD__:
64-
return ::imex::ptensor::MODULO;
64+
return ::imex::ndarray::MODULO;
6565
case __RSHIFT__:
6666
case BITWISE_RIGHT_SHIFT:
6767
case __RRSHIFT__:
68-
return ::imex::ptensor::BITWISE_RIGHT_SHIFT;
68+
return ::imex::ndarray::BITWISE_RIGHT_SHIFT;
6969
case __AND__:
7070
case BITWISE_AND:
7171
case __RAND__:
72-
return ::imex::ptensor::BITWISE_AND;
72+
return ::imex::ndarray::BITWISE_AND;
7373
case __OR__:
7474
case BITWISE_OR:
7575
case __ROR__:
76-
return ::imex::ptensor::BITWISE_OR;
76+
return ::imex::ndarray::BITWISE_OR;
7777
case __XOR__:
7878
case BITWISE_XOR:
7979
case __RXOR__:
80-
return ::imex::ptensor::BITWISE_XOR;
80+
return ::imex::ndarray::BITWISE_XOR;
8181
default:
8282
throw std::runtime_error("Unknown/invalid elementwise binary operation");
8383
}
@@ -120,10 +120,10 @@ struct DeferredEWBinOp : public Deferred {
120120
auto av = dm.getDependent(builder, _a);
121121
auto bv = dm.getDependent(builder, _b);
122122

123-
auto aTyp = av.getType().cast<::imex::ptensor::PTensorType>();
124-
auto bTyp = bv.getType().cast<::imex::ptensor::PTensorType>();
123+
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
124+
auto bTyp = bv.getType().cast<::imex::ndarray::NDArrayType>();
125125
auto outElemType =
126-
::imex::ptensor::toMLIR(builder, SHARPY::jit::getPTDType(_dtype));
126+
::imex::ndarray::toMLIR(builder, SHARPY::jit::getPTDType(_dtype));
127127
auto outTyp = aTyp.cloneWith(shape(), outElemType);
128128

129129
::mlir::Value one, two;
@@ -134,7 +134,7 @@ struct DeferredEWBinOp : public Deferred {
134134
one = av;
135135
two = bv;
136136
}
137-
auto bop = builder.create<::imex::ptensor::EWBinOp>(
137+
auto bop = builder.create<::imex::ndarray::EWBinOp>(
138138
loc, outTyp, builder.getI32IntegerAttr(sharpy2mlir(_op)), one, two);
139139

140140
dm.addVal(this->guid(), bop,

src/EWUnyOp.cpp

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ namespace x {
2020
class EWUnyOp
2121
{
2222
public:
23-
using ptr_type = DPTensorBaseX::ptr_type;
23+
using ptr_type = DNDArrayBaseX::ptr_type;
2424

2525
template<typename T>
26-
static ptr_type op(EWUnyOpId uop, const std::shared_ptr<DPTensorX<T>> & a_ptr)
26+
static ptr_type op(EWUnyOpId uop, const std::shared_ptr<DNDArrayX<T>> & a_ptr)
2727
{
2828
const auto & ax = a_ptr->xarray();
2929
if(a_ptr->is_sliced()) {
@@ -35,7 +35,7 @@ namespace x {
3535

3636
#pragma GCC diagnostic ignored "-Wswitch"
3737
template<typename T1, typename T>
38-
static ptr_type do_op(EWUnyOpId uop, const T1 & a, const std::shared_ptr<DPTensorX<T>> & a_ptr)
38+
static ptr_type do_op(EWUnyOpId uop, const T1 & a, const std::shared_ptr<DNDArrayX<T>> & a_ptr)
3939
{
4040
switch(uop) {
4141
case __ABS__:
@@ -122,72 +122,72 @@ namespace x {
122122
} //namespace x
123123
#endif // if 0
124124

125-
// convert id of our unary op to id of imex::ptensor unary op
126-
static ::imex::ptensor::EWUnyOpId sharpy(const EWUnyOpId uop) {
125+
// convert id of our unary op to id of imex::ndarray unary op
126+
static ::imex::ndarray::EWUnyOpId sharpy(const EWUnyOpId uop) {
127127
switch (uop) {
128128
case __ABS__:
129129
case ABS:
130-
return ::imex::ptensor::ABS;
130+
return ::imex::ndarray::ABS;
131131
case ACOS:
132-
return ::imex::ptensor::ACOS;
132+
return ::imex::ndarray::ACOS;
133133
case ACOSH:
134-
return ::imex::ptensor::ACOSH;
134+
return ::imex::ndarray::ACOSH;
135135
case ASIN:
136-
return ::imex::ptensor::ASIN;
136+
return ::imex::ndarray::ASIN;
137137
case ASINH:
138-
return ::imex::ptensor::ASINH;
138+
return ::imex::ndarray::ASINH;
139139
case ATAN:
140-
return ::imex::ptensor::ATAN;
140+
return ::imex::ndarray::ATAN;
141141
case ATANH:
142-
return ::imex::ptensor::ATANH;
142+
return ::imex::ndarray::ATANH;
143143
case CEIL:
144-
return ::imex::ptensor::CEIL;
144+
return ::imex::ndarray::CEIL;
145145
case COS:
146-
return ::imex::ptensor::COS;
146+
return ::imex::ndarray::COS;
147147
case COSH:
148-
return ::imex::ptensor::COSH;
148+
return ::imex::ndarray::COSH;
149149
case EXP:
150-
return ::imex::ptensor::EXP;
150+
return ::imex::ndarray::EXP;
151151
case EXPM1:
152-
return ::imex::ptensor::EXPM1;
152+
return ::imex::ndarray::EXPM1;
153153
case FLOOR:
154-
return ::imex::ptensor::FLOOR;
154+
return ::imex::ndarray::FLOOR;
155155
case ISFINITE:
156-
return ::imex::ptensor::ISFINITE;
156+
return ::imex::ndarray::ISFINITE;
157157
case ISINF:
158-
return ::imex::ptensor::ISINF;
158+
return ::imex::ndarray::ISINF;
159159
case ISNAN:
160-
return ::imex::ptensor::ISNAN;
160+
return ::imex::ndarray::ISNAN;
161161
case LOG:
162-
return ::imex::ptensor::LOG;
162+
return ::imex::ndarray::LOG;
163163
case LOG1P:
164-
return ::imex::ptensor::LOG1P;
164+
return ::imex::ndarray::LOG1P;
165165
case LOG2:
166-
return ::imex::ptensor::LOG2;
166+
return ::imex::ndarray::LOG2;
167167
case LOG10:
168-
return ::imex::ptensor::LOG10;
168+
return ::imex::ndarray::LOG10;
169169
case ROUND:
170-
return ::imex::ptensor::ROUND;
170+
return ::imex::ndarray::ROUND;
171171
case SIGN:
172-
return ::imex::ptensor::SIGN;
172+
return ::imex::ndarray::SIGN;
173173
case SIN:
174-
return ::imex::ptensor::SIN;
174+
return ::imex::ndarray::SIN;
175175
case SINH:
176-
return ::imex::ptensor::SINH;
176+
return ::imex::ndarray::SINH;
177177
case SQUARE:
178-
return ::imex::ptensor::SQUARE;
178+
return ::imex::ndarray::SQUARE;
179179
case SQRT:
180-
return ::imex::ptensor::SQRT;
180+
return ::imex::ndarray::SQRT;
181181
case TAN:
182-
return ::imex::ptensor::TAN;
182+
return ::imex::ndarray::TAN;
183183
case TANH:
184-
return ::imex::ptensor::TANH;
184+
return ::imex::ndarray::TANH;
185185
case TRUNC:
186-
return ::imex::ptensor::TRUNC;
186+
return ::imex::ndarray::TRUNC;
187187
case ERF:
188-
return ::imex::ptensor::ERF;
188+
return ::imex::ndarray::ERF;
189189
case LOGICAL_NOT:
190-
return ::imex::ptensor::LOGICAL_NOT;
190+
return ::imex::ndarray::LOGICAL_NOT;
191191
case __NEG__:
192192
case NEGATIVE:
193193
case __POS__:
@@ -210,10 +210,10 @@ struct DeferredEWUnyOp : public Deferred {
210210
jit::DepManager &dm) override {
211211
auto av = dm.getDependent(builder, _a);
212212

213-
auto aTyp = av.getType().cast<::imex::ptensor::PTensorType>();
213+
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
214214
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
215215

216-
auto uop = builder.create<::imex::ptensor::EWUnyOp>(
216+
auto uop = builder.create<::imex::ndarray::EWUnyOp>(
217217
loc, outTyp, builder.getI32IntegerAttr(sharpy(_op)), av);
218218

219219
dm.addVal(this->guid(), uop,

0 commit comments

Comments
 (0)