diff --git a/jitify2.hpp b/jitify2.hpp
index a32e001..76599cb 100644
--- a/jitify2.hpp
+++ b/jitify2.hpp
@@ -1723,6 +1723,7 @@ class LibCuda
                              int, CUkernel, CUdevice)
   JITIFY_DEFINE_CUDA_WRAPPER(KernelGetAttribute, CUresult, int*,
                              CUfunction_attribute, CUkernel, CUdevice)
+  JITIFY_DEFINE_CUDA_WRAPPER(KernelGetFunction, CUresult, CUfunction*, CUkernel)
 #endif
 #undef JITIFY_DEFINE_CUDA_WRAPPER
 #undef JITIFY_STR
@@ -2562,8 +2563,19 @@ inline ConfiguredKernel ConfiguredKernel::configure_1d_max_occupancy(
     unsigned int flags) {
   int grid, block;
   if (!cuda()) return Error(cuda().error());
-  CUresult ret = cuda().OccupancyMaxPotentialBlockSizeWithFlags()(
-      &grid, &block, (CUfunction)kernel.function(),
+  // The occupancy query takes a CUfunction; with context-independent loading
+  // kernel.function() is treated as a CUkernel, so resolve it first.
+  CUfunction cu_func = NULL;
+  CUresult ret;
+#if JITIFY_USE_CONTEXT_INDEPENDENT_LOADING
+  ret = cuda().KernelGetFunction()(&cu_func, (CUkernel)kernel.function());
+  if (ret != CUDA_SUCCESS) {
+    return Error("Configure failed getting Kernel Function: " + detail::get_cuda_error_string(ret));
+  }
+#else
+  cu_func = (CUfunction)kernel.function();
+#endif
+  ret = cuda().OccupancyMaxPotentialBlockSizeWithFlags()(
+      &grid, &block, cu_func,
       shared_memory_bytes_callback, shared_memory_bytes, max_block_size, flags);
   if (ret != CUDA_SUCCESS) {
     return Error("Configure failed: " + detail::get_cuda_error_string(ret));