diff --git a/jitify.hpp b/jitify.hpp index 21692e4..68d18d7 100644 --- a/jitify.hpp +++ b/jitify.hpp @@ -687,6 +687,73 @@ inline bool load_source( comment; } + // WAR where nvrtc can fail to correctly return if an include is present + size_t has_include_start = cleanline.find("__has_include"); + if (has_include_start != std::string::npos) { + // find subsequent opening and closing braces + size_t open = cleanline.find("(", has_include_start); + size_t close = cleanline.find(")", open); + if (!(open == std::string::npos || close == std::string::npos)) { + std::string has_name = cleanline.substr(open + 1, close - open - 1); + + if (has_name.find("<") == std::string::npos && has_name.find("\"") == std::string::npos) + throw std::runtime_error("Malformed __has_include statement (" + + filename + ":" + std::to_string(linenum) + + ")"); + // are we using quote includes or angle brackets? + // (test using angle brackets, since quotes are valid around angle brackets) + bool quote_include = has_name.find("<") == std::string::npos; + size_t header_start = (quote_include ? has_name.find("\"") : has_name.find("<")) + 1; + size_t header_count = has_name.find(quote_include ? "\"" : ">", header_start) - header_start; + if (has_name.find(quote_include ? "\"" : ">", header_start) == std::string::npos) + throw std::runtime_error("Malformed __has_include statement (" + + filename + ":" + std::to_string(linenum) + + ")"); + + if (header_count != 0) { + std::string has_include_name = + has_name.substr(header_start, header_count); + +#if JITIFY_PRINT_HEADER_PATHS + std::cout << "Found #if __has_include(" << has_name << ")" + << " from " << filename << ":" << linenum << std::endl; +#endif + // Try loading from filesystem + bool found_file = false; + std::string has_include_fullpath = + path_join(current_dir, has_include_name); + if (quote_include) { + file_stream.open(has_include_fullpath.c_str()); + if (file_stream) found_file = true; + } + // Search include directories + if (!found_file) { + for (int i = 0; i < (int)include_paths.size(); ++i) { + has_include_fullpath = + path_join(include_paths[i], has_include_name); + file_stream.open(has_include_fullpath.c_str()); + if (file_stream) { + found_file = true; + break; + } + } + if (!found_file) { + // Try loading from builtin headers + has_include_fullpath = + path_join("__jitify_builtin", has_include_name); + auto it = get_jitsafe_headers_map().find(has_include_name); + if (it != get_jitsafe_headers_map().end()) { + found_file = true; + } + } + } + + line = cleanline.substr(0, has_include_start) + (found_file ? "(1)" : "(0)") + + cleanline.substr(close + 1); + } + } + } + source += line + "\n"; } // HACK TESTING (WAR for cub) diff --git a/jitify_test.cu b/jitify_test.cu index 4a79806..46d0ac7 100644 --- a/jitify_test.cu +++ b/jitify_test.cu @@ -961,8 +961,6 @@ static const char* const builtin_numeric_cuda_std_limits_program_source = "builtin_numeric_cuda_std_limits_program\n" "#include \n" "#include \n" - "#include \n" // test fails without this explicit include - "#include \n" "struct MyType {};\n" "namespace cuda {\n" "namespace std {\n" @@ -1138,6 +1136,62 @@ TEST(JitifyTest, EnvVarOptions) { setenv("JITIFY_OPTIONS", "", true); } +static const char* const has_include_source = R"( + #if __has_include() + #else + #error __has_include failed + #endif + + #if 1 && __has_include() + #else + #error __has_include failed + #endif + + #if __has_include() && 0 + #error __has_include failed + #else + #endif + + #if __has_include("") + #else + #error __has_include failed + #endif + + #if __has_include("limits") + #else + #error __has_include failed + #endif + + #if __has_include("example_headers/my_header1.cuh") + #else + #error __has_include failed + #endif + + // check we don't touch these + #if defined(__has_include) + #else + #error __has_include failed + #endif + + #if !defined(__has_include) + #error __has_include failed + #endif + + __global__ void has_include_kernel() { } +)"; + +TEST(JitifyTest, HasInclude) { + // Checks that cassert works as expected + jitify::JitCache kernel_cache; + auto program = kernel_cache.program(has_include_source); + dim3 grid(1); + dim3 block(1); + CHECK_CUDA((program.kernel("has_include_kernel") + .instantiate<>() + .configure(grid, block) + .launch())); +} + // NOTE: This MUST be the last test in the file, due to sticky CUDA error. TEST(JitifyTest, AssertHeader) { // Checks that cassert works as expected