diff --git a/extension_cpp/extension_cpp/csrc/muladd.cpp b/extension_cpp/extension_cpp/csrc/muladd.cpp index d68332a..30e5667 100644 --- a/extension_cpp/extension_cpp/csrc/muladd.cpp +++ b/extension_cpp/extension_cpp/csrc/muladd.cpp @@ -10,7 +10,7 @@ extern "C" { The import from Python will load the .so consisting of this file in this extension, so that the TORCH_LIBRARY static initializers below are run. */ - PyObject* PyInit__C(void) + PyMODINIT_FUNC PyInit__C(void) { static struct PyModuleDef module_def = { PyModuleDef_HEAD_INIT, diff --git a/extension_cpp/setup.py b/extension_cpp/setup.py index 33a2a99..6ea2785 100644 --- a/extension_cpp/setup.py +++ b/extension_cpp/setup.py @@ -33,6 +33,14 @@ def get_extensions(): use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None extension = CUDAExtension if use_cuda else CppExtension + NVCC_FLAGS = [ + "-O3" if not debug_mode else "-O0", + ] + if os.name == "nt": + NVCC_FLAGS += [ + "-DUSE_CUDA=1", + ] + extra_link_args = [] extra_compile_args = { "cxx": [ @@ -40,9 +48,7 @@ def get_extensions(): "-fdiagnostics-color=always", "-DPy_LIMITED_API=0x03090000", # min CPython version 3.9 ], - "nvcc": [ - "-O3" if not debug_mode else "-O0", - ], + "nvcc": NVCC_FLAGS, } if debug_mode: extra_compile_args["cxx"].append("-g") diff --git a/extension_cpp_stable/extension_cpp_stable/csrc/muladd.cpp b/extension_cpp_stable/extension_cpp_stable/csrc/muladd.cpp index c843c58..6d83d6a 100644 --- a/extension_cpp_stable/extension_cpp_stable/csrc/muladd.cpp +++ b/extension_cpp_stable/extension_cpp_stable/csrc/muladd.cpp @@ -11,7 +11,7 @@ extern "C" { The import from Python will load the .so consisting of this file in this extension, so that the STABLE_TORCH_LIBRARY static initializers below are run. */ - PyObject* PyInit__C(void) + PyMODINIT_FUNC PyInit__C(void) { static struct PyModuleDef module_def = { PyModuleDef_HEAD_INIT,