Skip to content

Commit 27f8713

Browse files
Add dynamic shared memory allocation
1 parent 17b1fb7 commit 27f8713

8 files changed

Lines changed: 31 additions & 9 deletions

File tree

platforms/artic/intrinsics_thorin.impala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
#[import(cc = "thorin")] fn cmpxchg_weak[T](_addr: &mut T, _cmp: T, _new: T, _success_order: u32, _failure_order: u32, _scope: &[u8]) -> (T, bool); // only for integer data types
1313
#[import(cc = "thorin")] fn fence(_order: u32, _scope: &[u8]) -> ();
1414
#[import(cc = "thorin")] fn pe_info[T](_src: &[u8], _val: T) -> ();
15-
#[import(cc = "thorin")] fn cuda(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
16-
#[import(cc = "thorin")] fn nvvm(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
17-
#[import(cc = "thorin")] fn opencl(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
18-
#[import(cc = "thorin")] fn amdgpu(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
15+
#[import(cc = "thorin", name = "cuda")] fn cuda_with_shared(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _smem: i32, _body: fn() -> ()) -> ();
16+
#[import(cc = "thorin", name = "nvvm")] fn nvvm_with_shared(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _smem: i32, _body: fn() -> ()) -> ();
17+
#[import(cc = "thorin", name = "opencl")] fn opencl_with_shared(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _smem: i32, _body: fn() -> ()) -> ();
18+
#[import(cc = "thorin", name = "amdgpu")] fn amdgpu_with_shared(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _smem: i32, _body: fn() -> ()) -> ();
1919
#[import(cc = "thorin")] fn reserve_shared[T](_size: i32) -> &mut addrspace(3)[T];
2020
#[import(cc = "thorin")] fn hls(_dev: i32, _body: fn() -> ()) -> ();
2121
#[import(cc = "thorin", name = "pipeline")] fn thorin_pipeline(_initiation_interval: i32, _lower: i32, _upper: i32, _body: fn(i32) -> ()) -> (); // only for HLS/OpenCL backend
@@ -35,6 +35,11 @@
3535
#[import(cc = "thorin", name = "cmpxchg_weak")] fn cmpxchg_weak_p1[T](_addr: &mut addrspace(1)T, _cmp: T, _new: T, _success_order: u32, _failure_order: u32, _scope: &[u8]) -> (T, bool);
3636
#[import(cc = "thorin", name = "cmpxchg_weak")] fn cmpxchg_weak_p3[T](_addr: &mut addrspace(3)T, _cmp: T, _new: T, _success_order: u32, _failure_order: u32, _scope: &[u8]) -> (T, bool);
3737

38+
fn @cuda(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = cuda_with_shared(dev, grid, block, 0, body);
39+
fn @nvvm(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = nvvm_with_shared(dev, grid, block, 0, body);
40+
fn @opencl(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = opencl_with_shared(dev, grid, block, 0, body);
41+
fn @amdgpu(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) = amdgpu_with_shared(dev, grid, block, 0, body);
42+
3843
fn @pipeline(body: fn(i32) -> ()) = @|initiation_interval: i32, lower: i32, upper: i32| thorin_pipeline(initiation_interval, lower, upper, body);
3944
fn @parallel(body: fn(i32) -> ()) = @|num_threads: i32, lower: i32, upper: i32| thorin_parallel(num_threads, lower, upper, body);
4045
fn @spawn(body: fn() -> ()) = @|| thorin_spawn(body);

platforms/impala/intrinsics_thorin.impala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ extern "thorin" {
1010
fn insert[T, U](T, i32, U) -> T;
1111
//fn shuffle[T](T, T, T) -> T;
1212

13-
fn cuda(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
14-
fn nvvm(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
15-
fn opencl(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
16-
fn amdgpu(i32, (i32, i32, i32), (i32, i32, i32), fn() -> ()) -> ();
13+
fn "cuda" cuda_with_shared(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
14+
fn "nvvm" nvvm_with_shared(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
15+
fn "opencl" opencl_with_shared(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
16+
fn "amdgpu" amdgpu_with_shared(i32, (i32, i32, i32), (i32, i32, i32), i32, fn() -> ()) -> ();
1717
fn reserve_shared[T](i32) -> &mut[3][T];
1818

1919
fn hls(dev: i32, body: fn() -> ()) -> ();
@@ -42,3 +42,8 @@ extern "thorin" {
4242

4343
fn vectorize(vector_length: i32, body: fn(i32) -> ()) -> ();
4444
}
45+
46+
fn @@cuda(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { cuda_with_shared(dev, grid, block, 0, body) }
47+
fn @@nvvm(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { nvvm_with_shared(dev, grid, block, 0, body) }
48+
fn @@opencl(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { opencl_with_shared(dev, grid, block, 0, body) }
49+
fn @@amdgpu(dev: i32, grid: (i32, i32, i32), block: (i32, i32, i32), body: fn() -> ()) { amdgpu_with_shared(dev, grid, block, 0, body) }

src/anydsl_runtime.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ void anydsl_copy(
117117
void anydsl_launch_kernel(
118118
int32_t mask, const char* file_name, const char* kernel_name,
119119
const uint32_t* grid, const uint32_t* block,
120+
uint32_t smem,
120121
void** arg_data,
121122
const uint32_t* arg_sizes,
122123
const uint32_t* arg_aligns,
@@ -128,6 +129,7 @@ void anydsl_launch_kernel(
128129
kernel_name,
129130
grid,
130131
block,
132+
smem,
131133
{
132134
arg_data,
133135
arg_sizes,

src/anydsl_runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ AnyDSL_runtime_API void anydsl_copy(int32_t, const void*, int64_t, int32_t, void
3636
AnyDSL_runtime_API void anydsl_launch_kernel(
3737
int32_t, const char*, const char*,
3838
const uint32_t*, const uint32_t*,
39+
uint32_t,
3940
void**, const uint32_t*, const uint32_t*, const uint32_t*, const uint8_t*,
4041
uint32_t);
4142
AnyDSL_runtime_API void anydsl_synchronize(int32_t);

src/cuda_platform.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ void CudaPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_params
225225
launch_params.grid[1] / launch_params.block[1],
226226
launch_params.grid[2] / launch_params.block[2],
227227
launch_params.block[0], launch_params.block[1], launch_params.block[2],
228-
0, nullptr, launch_params.args.data, nullptr);
228+
launch_params.smem,
229+
nullptr,
230+
launch_params.args.data, nullptr);
229231
CHECK_CUDA(err, "cuLaunchKernel()");
230232

231233
if (runtime_->profiling_enabled()) {

src/hsa_platform.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@ void HSAPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_params)
327327
if (!queue)
328328
error("The selected HSA device '%' cannot execute kernels", dev);
329329

330+
if (launch_params.smem != 0)
331+
error("Dynamic shared memory not yet supported on this platform");
332+
330333
auto& kernel_info = load_kernel(dev, launch_params.file_name, launch_params.kernel_name);
331334

332335
auto align_up = [&] (unsigned int start, unsigned int align) -> unsigned int {

src/opencl_platform.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,9 @@ void OpenCLPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_para
363363
return;
364364
}
365365

366+
if (launch_params.smem != 0)
367+
error("Dynamic shared memory not yet supported on this platform");
368+
366369
auto kernel = load_kernel(dev, launch_params.file_name, launch_params.kernel_name);
367370

368371
// set up arguments

src/runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ struct LaunchParams {
2727
const char* kernel_name;
2828
const uint32_t* grid;
2929
const uint32_t* block;
30+
uint32_t smem;
3031
struct {
3132
void** data;
3233
const uint32_t* sizes;

0 commit comments

Comments
 (0)