Skip to content

Commit b8b8f0b

Browse files
committed
API for arith header from ArrayFire
1 parent 0fc739b commit b8b8f0b

File tree

4 files changed

+235
-48
lines changed

4 files changed

+235
-48
lines changed

src/algorithm/mod.rs

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,69 +11,39 @@ type AfArray = self::libc::c_longlong;
1111
#[allow(dead_code)]
1212
extern {
1313
fn af_sum(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
14-
1514
//fn af_sum_nan(out: MutAfArray, input: AfArray, dim: c_int, nanval: c_double) -> c_int;
16-
1715
fn af_product(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
18-
1916
//fn af_product_nan(out: MutAfArray, input: AfArray, dim: c_int, val: c_double) -> c_int;
20-
2117
fn af_min(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
22-
2318
fn af_max(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
24-
2519
fn af_all_true(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
26-
2720
fn af_any_true(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
28-
2921
fn af_count(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
30-
3122
fn af_sum_all(r: MutDouble, i: MutDouble, input: AfArray) -> c_int;
32-
3323
//fn af_sum_nan_all(r: MutDouble, i: MutDouble, input: AfArray, val: c_double) -> c_int;
34-
3524
fn af_product_all(r: MutDouble, i: MutDouble, input: AfArray) -> c_int;
36-
3725
//fn af_product_nan_all(r: MutDouble, i: MutDouble, input: AfArray, val: c_double) -> c_int;
38-
3926
fn af_min_all(r: MutDouble, i: MutDouble, input: AfArray) -> c_int;
40-
4127
fn af_max_all(r: MutDouble, i: MutDouble, input: AfArray) -> c_int;
42-
4328
fn af_all_true_all(r: MutDouble, i: MutDouble, input: AfArray) -> c_int;
44-
4529
fn af_any_true_all(r: MutDouble, i: MutDouble, input: AfArray) -> c_int;
46-
4730
fn af_count_all(r: MutDouble, i: MutDouble, input: AfArray) -> c_int;
48-
4931
fn af_imin(out: MutAfArray, idx: MutAfArray, input: AfArray, dim: c_int) -> c_int;
50-
5132
fn af_imax(out: MutAfArray, idx: MutAfArray, input: AfArray, dim: c_int) -> c_int;
52-
5333
fn af_imin_all(r: MutDouble, i: MutDouble, idx: MutUint, input: AfArray) -> c_int;
54-
5534
fn af_imax_all(r: MutDouble, i: MutDouble, idx: MutUint, input: AfArray) -> c_int;
56-
5735
fn af_accum(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
58-
5936
fn af_where(out: MutAfArray, input: AfArray) -> c_int;
60-
6137
fn af_diff1(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
62-
6338
fn af_diff2(out: MutAfArray, input: AfArray, dim: c_int) -> c_int;
64-
6539
fn af_sort(out: MutAfArray, input: AfArray, dim: c_uint, ascend: c_int) -> c_int;
66-
6740
fn af_sort_index(o: MutAfArray, i: MutAfArray, inp: AfArray, d: c_uint, a: c_int) -> c_int;
68-
69-
fn af_sort_by_key(out_keys: MutAfArray, out_vals: MutAfArray,
70-
in_keys: AfArray, in_vals: AfArray, dim: c_uint, ascend: c_int) -> c_int;
71-
7241
fn af_set_unique(out: MutAfArray, input: AfArray, is_sorted: c_int) -> c_int;
73-
7442
fn af_set_union(out: MutAfArray, first: AfArray, second: AfArray, is_unq: c_int) -> c_int;
75-
7643
fn af_set_intersect(out: MutAfArray, one: AfArray, two: AfArray, is_unq: c_int) -> c_int;
44+
45+
fn af_sort_by_key(out_keys: MutAfArray, out_vals: MutAfArray,
46+
in_keys: AfArray, in_vals: AfArray, dim: c_uint, ascend: c_int) -> c_int;
7747
}
7848

7949
#[allow(unused_mut)]

src/arith/mod.rs

Lines changed: 226 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,256 @@
11
extern crate libc;
2+
extern crate num;
23

34
use super::Array as Array;
45
use self::libc::{c_int};
56
use data::constant;
7+
use self::num::Complex;
68

79
type MutAfArray = *mut self::libc::c_longlong;
810
type MutDouble = *mut self::libc::c_double;
911
type MutUint = *mut self::libc::c_uint;
1012
type AfArray = self::libc::c_longlong;
1113

12-
use std::ops::Add;
14+
use std::ops::{Add, Sub, Div, Mul, BitAnd, BitOr, BitXor, Not, Rem, Shl, Shr};
1315

1416
#[allow(dead_code)]
1517
extern {
1618
fn af_add(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
19+
fn af_sub(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
20+
fn af_mul(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
21+
fn af_div(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
22+
23+
fn af_lt(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
24+
fn af_gt(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
25+
fn af_le(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
26+
fn af_ge(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
27+
fn af_eq(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
28+
fn af_or(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
29+
30+
fn af_neq(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
31+
fn af_and(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
32+
fn af_rem(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
33+
fn af_mod(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
34+
35+
fn af_bitand(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
36+
fn af_bitor(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
37+
fn af_bitxor(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
38+
fn af_bitshiftl(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
39+
fn af_bitshiftr(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
40+
fn af_minof(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
41+
fn af_maxof(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
42+
43+
fn af_not(out: MutAfArray, arr: AfArray) -> c_int;
44+
fn af_abs(out: MutAfArray, arr: AfArray) -> c_int;
45+
fn af_arg(out: MutAfArray, arr: AfArray) -> c_int;
46+
fn af_sign(out: MutAfArray, arr: AfArray) -> c_int;
47+
fn af_ceil(out: MutAfArray, arr: AfArray) -> c_int;
48+
fn af_round(out: MutAfArray, arr: AfArray) -> c_int;
49+
fn af_trunc(out: MutAfArray, arr: AfArray) -> c_int;
50+
fn af_floor(out: MutAfArray, arr: AfArray) -> c_int;
51+
52+
fn af_hypot(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
1753

1854
fn af_sin(out: MutAfArray, arr: AfArray) -> c_int;
55+
fn af_cos(out: MutAfArray, arr: AfArray) -> c_int;
56+
fn af_tan(out: MutAfArray, arr: AfArray) -> c_int;
57+
fn af_asin(out: MutAfArray, arr: AfArray) -> c_int;
58+
fn af_acos(out: MutAfArray, arr: AfArray) -> c_int;
59+
fn af_atan(out: MutAfArray, arr: AfArray) -> c_int;
60+
61+
fn af_atan2(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
62+
fn af_cplx2(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
63+
fn af_root(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
64+
fn af_pow(out: MutAfArray, lhs: AfArray, rhs: AfArray, batch: c_int) -> c_int;
1965

66+
fn af_cplx(out: MutAfArray, arr: AfArray) -> c_int;
67+
fn af_real(out: MutAfArray, arr: AfArray) -> c_int;
68+
fn af_imag(out: MutAfArray, arr: AfArray) -> c_int;
69+
fn af_conjg(out: MutAfArray, arr: AfArray) -> c_int;
70+
fn af_sinh(out: MutAfArray, arr: AfArray) -> c_int;
71+
fn af_cosh(out: MutAfArray, arr: AfArray) -> c_int;
72+
fn af_tanh(out: MutAfArray, arr: AfArray) -> c_int;
73+
fn af_asinh(out: MutAfArray, arr: AfArray) -> c_int;
74+
fn af_acosh(out: MutAfArray, arr: AfArray) -> c_int;
75+
fn af_atanh(out: MutAfArray, arr: AfArray) -> c_int;
76+
fn af_pow2(out: MutAfArray, arr: AfArray) -> c_int;
77+
fn af_exp(out: MutAfArray, arr: AfArray) -> c_int;
78+
fn af_expm1(out: MutAfArray, arr: AfArray) -> c_int;
79+
fn af_erf(out: MutAfArray, arr: AfArray) -> c_int;
80+
fn af_erfc(out: MutAfArray, arr: AfArray) -> c_int;
81+
fn af_log(out: MutAfArray, arr: AfArray) -> c_int;
82+
fn af_log1p(out: MutAfArray, arr: AfArray) -> c_int;
83+
fn af_log10(out: MutAfArray, arr: AfArray) -> c_int;
84+
fn af_log2(out: MutAfArray, arr: AfArray) -> c_int;
85+
fn af_sqrt(out: MutAfArray, arr: AfArray) -> c_int;
86+
fn af_cbrt(out: MutAfArray, arr: AfArray) -> c_int;
87+
fn af_factorial(out: MutAfArray, arr: AfArray) -> c_int;
88+
fn af_tgamma(out: MutAfArray, arr: AfArray) -> c_int;
89+
fn af_lgamma(out: MutAfArray, arr: AfArray) -> c_int;
90+
fn af_iszero(out: MutAfArray, arr: AfArray) -> c_int;
91+
fn af_isinf(out: MutAfArray, arr: AfArray) -> c_int;
92+
fn af_isnan(out: MutAfArray, arr: AfArray) -> c_int;
2093
}
2194

22-
impl Add<f64> for Array {
95+
impl Not for Array {
2396
type Output = Array;
2497

25-
fn add(self, rhs: f64) -> Array {
26-
let cnst_arr = constant(rhs, self.dims());
98+
fn not(self) -> Array {
2799
unsafe {
28100
let mut temp: i64 = 0;
29-
af_add(&mut temp as MutAfArray, self.get() as AfArray, cnst_arr.get() as AfArray, 0);
101+
af_not(&mut temp as MutAfArray, self.get() as AfArray);
30102
Array {handle: temp}
31103
}
32104
}
33105
}
34106

35-
#[allow(unused_mut)]
36-
pub fn sin(input: &Array) -> Array {
37-
unsafe {
38-
let mut temp: i64 = 0;
39-
af_sin(&mut temp as MutAfArray, input.get() as AfArray);
40-
Array {handle: temp}
41-
}
107+
macro_rules! unary_func {
108+
($fn_name: ident, $ffi_fn: ident) => (
109+
#[allow(unused_mut)]
110+
pub fn $fn_name(input: &Array) -> Array {
111+
unsafe {
112+
let mut temp: i64 = 0;
113+
$ffi_fn(&mut temp as MutAfArray, input.get() as AfArray);
114+
Array {handle: temp}
115+
}
116+
}
117+
)
118+
}
119+
120+
unary_func!(abs, af_abs);
121+
unary_func!(arg, af_arg);
122+
unary_func!(sign, af_sign);
123+
unary_func!(round, af_round);
124+
unary_func!(trunc, af_trunc);
125+
unary_func!(floor, af_floor);
126+
unary_func!(ceil, af_ceil);
127+
unary_func!(sin, af_sin);
128+
unary_func!(cos, af_cos);
129+
unary_func!(tan, af_tan);
130+
unary_func!(asin, af_asin);
131+
unary_func!(acos, af_acos);
132+
unary_func!(atan, af_atan);
133+
unary_func!(cplx, af_cplx);
134+
unary_func!(real, af_real);
135+
unary_func!(imag, af_imag);
136+
unary_func!(conjg, af_conjg);
137+
unary_func!(sinh, af_sinh);
138+
unary_func!(cosh, af_cosh);
139+
unary_func!(tanh, af_tanh);
140+
unary_func!(asinh, af_asinh);
141+
unary_func!(acosh, af_acosh);
142+
unary_func!(atanh, af_atanh);
143+
unary_func!(pow2, af_pow2);
144+
unary_func!(exp, af_exp);
145+
unary_func!(expm1, af_expm1);
146+
unary_func!(erf, af_erf);
147+
unary_func!(erfc, af_erfc);
148+
unary_func!(log, af_log);
149+
unary_func!(log1p, af_log1p);
150+
unary_func!(log10, af_log10);
151+
unary_func!(log2, af_log2);
152+
unary_func!(sqrt, af_sqrt);
153+
unary_func!(cbrt, af_cbrt);
154+
unary_func!(factorial, af_factorial);
155+
unary_func!(tgamma, af_tgamma);
156+
unary_func!(lgamma, af_lgamma);
157+
unary_func!(iszero, af_iszero);
158+
unary_func!(isinf, af_isinf);
159+
unary_func!(isnan, af_isnan);
160+
161+
macro_rules! binary_func {
162+
($fn_name: ident, $ffi_fn: ident) => (
163+
#[allow(unused_mut)]
164+
pub fn $fn_name(lhs: &Array, rhs: &Array) -> Array {
165+
unsafe {
166+
let mut temp: i64 = 0;
167+
$ffi_fn(&mut temp as MutAfArray, lhs.get() as AfArray, rhs.get() as AfArray, 0);
168+
Array {handle: temp}
169+
}
170+
}
171+
)
42172
}
173+
174+
binary_func!(lt, af_lt);
175+
binary_func!(gt, af_gt);
176+
binary_func!(le, af_le);
177+
binary_func!(ge, af_ge);
178+
binary_func!(eq, af_eq);
179+
binary_func!(neq, af_neq);
180+
binary_func!(and, af_and);
181+
binary_func!(or, af_or);
182+
binary_func!(minof, af_minof);
183+
binary_func!(maxof, af_maxof);
184+
binary_func!(modulo, af_mod);
185+
binary_func!(hypot, af_hypot);
186+
binary_func!(atan2, af_atan2);
187+
binary_func!(cplx2, af_cplx2);
188+
binary_func!(root, af_root);
189+
binary_func!(pow, af_pow);
190+
191+
macro_rules! arith_scalar_func {
192+
($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => (
193+
impl $op_name<$rust_type> for Array {
194+
type Output = Array;
195+
196+
fn $fn_name(self, rhs: $rust_type) -> Array {
197+
let cnst_arr = constant(rhs, self.dims());
198+
unsafe {
199+
let mut temp: i64 = 0;
200+
$ffi_fn(&mut temp as MutAfArray,
201+
self.get() as AfArray, cnst_arr.get() as AfArray,
202+
0);
203+
Array {handle: temp}
204+
}
205+
}
206+
}
207+
)
208+
}
209+
210+
macro_rules! arith_scalar_spec {
211+
($ty_name:ty) => (
212+
arith_scalar_func!($ty_name, Add, add, af_add);
213+
arith_scalar_func!($ty_name, Sub, sub, af_sub);
214+
arith_scalar_func!($ty_name, Mul, mul, af_mul);
215+
arith_scalar_func!($ty_name, Div, div, af_div);
216+
)
217+
}
218+
219+
arith_scalar_spec!(Complex<f64>);
220+
arith_scalar_spec!(Complex<f32>);
221+
arith_scalar_spec!(f64);
222+
arith_scalar_spec!(f32);
223+
arith_scalar_spec!(u64);
224+
arith_scalar_spec!(i64);
225+
arith_scalar_spec!(u32);
226+
arith_scalar_spec!(i32);
227+
arith_scalar_spec!(u8);
228+
229+
macro_rules! arith_func {
230+
($op_name:ident, $fn_name:ident, $ffi_fn: ident) => (
231+
impl $op_name<Array> for Array {
232+
type Output = Array;
233+
234+
fn $fn_name(self, rhs: Array) -> Array {
235+
unsafe {
236+
let mut temp: i64 = 0;
237+
$ffi_fn(&mut temp as MutAfArray,
238+
self.get() as AfArray, rhs.get() as AfArray,
239+
0);
240+
Array {handle: temp}
241+
}
242+
}
243+
}
244+
)
245+
}
246+
247+
arith_func!(Add, add, af_add);
248+
arith_func!(Sub, sub, af_sub);
249+
arith_func!(Mul, mul, af_mul);
250+
arith_func!(Div, div, af_div);
251+
arith_func!(Rem, rem, af_rem);
252+
arith_func!(BitAnd, bitand, af_bitand);
253+
arith_func!(BitOr, bitor, af_bitor);
254+
arith_func!(BitXor, bitxor, af_bitxor);
255+
arith_func!(Shl, shl, af_bitshiftl);
256+
arith_func!(Shr, shr, af_bitshiftr);

src/data/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@ extern {
3737
t_ndims: c_uint, tdims: *const DimT, afdtype: c_int) -> c_int;
3838

3939
fn af_randu(out: MutAfArray, ndims: c_uint, dims: *const DimT, afdtype: c_int) -> c_int;
40-
4140
fn af_randn(out: MutAfArray, ndims: c_uint, dims: *const DimT, afdtype: c_int) -> c_int;
4241

4342
fn af_set_seed(seed: Uintl);
44-
4543
fn af_get_seed(seed: *mut Uintl);
4644

4745
}

src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ pub use algorithm::{accum, locate, diff1, diff2, sort, sort_index, sort_by_key};
3939
pub use algorithm::{set_unique, set_union, set_intersect};
4040
mod algorithm;
4141

42-
pub use arith::{sin};
42+
pub use arith::{lt, gt, le, ge, eq, neq, and, or, minof, maxof};
43+
pub use arith::{abs, sign, round, trunc, floor, ceil, modulo};
44+
pub use arith::{sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh};
45+
pub use arith::{atan2, cplx2, arg, cplx, real, imag, conjg, hypot};
46+
pub use arith::{sqrt, log, log1p, log10, log2, pow2, exp, expm1, erf, erfc, root, pow};
47+
pub use arith::{cbrt, factorial, tgamma, lgamma, iszero, isinf, isnan};
4348
mod arith;
4449

4550
pub use data::{constant, range, iota};

0 commit comments

Comments
 (0)