diff --git a/.gitignore b/.gitignore index 9f7128d495a9..8b13ab22b9bb 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,7 @@ third_party/build/ tools/coverage_plugins_package/pip-wheel-metadata/ tools/shared/_utils_internal.py tools/fast_nvcc/wrap_nvcc.sh +tools/fast_nvcc/wrap_nvcc.bat tools/fast_nvcc/tmp/ torch.egg-info/ torch/_C/__init__.pyi diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake index 7f45cd098447..839c43ea0482 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake @@ -768,13 +768,17 @@ endif() # FAST_NVCC if(USE_FAST_NVCC AND CUDA_NVCC_EXECUTABLE AND NOT CUDA_NVCC_EXECUTABLE_ORIGIN) set(CUDA_NVCC_EXECUTABLE_ORIGIN "${CUDA_NVCC_EXECUTABLE}") + set(EXTENSION "sh") + if (MSVC) + set(EXTENSION "bat") + endif() set(FAST_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/fast_nvcc.py") - configure_file(${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh.in "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh") - file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh" + configure_file(${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}.in "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}") + file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}" DESTINATION "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/" FILE_PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE ) - set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh") + set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}") endif() mark_as_advanced(CUDA_NVCC_EXECUTABLE) diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py index 659d91ae3c1f..285a2032dfbd 100755 --- a/tools/fast_nvcc/fast_nvcc.py +++ b/tools/fast_nvcc/fast_nvcc.py @@ -78,6 +78,8 @@ url_vars = f"{url_base}#keeping-intermediate-phase-files" # regex for temporary file names re_tmp = r"(? None: @@ -141,7 +143,10 @@ def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData: print(result.stdout, end="") env = {} commands = [] - for line in result.stderr.splitlines(): + output = result.stderr + if os.name == "nt": + output = result.stdout + for line in output.splitlines(): match = re.match(r"^#\$ (.*)$", line) if match: (stripped,) = match.groups() @@ -213,9 +218,11 @@ def unique_module_id_files(commands: List[str]) -> List[str]: line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line) if arr: (filename,) = arr + if os.name == "nt": + filename = "%TEMP%\\" + filename if not module_id: module_id = module_id_contents(shlex.split(line)) - uniqueified.append(f"echo -n '{module_id}' > '{filename}'") + uniqueified.append(f"echo -n '{module_id}' > \"{filename}\"") uniqueified.append(line) return uniqueified @@ -261,6 +268,8 @@ def files_mentioned(command: str) -> List[str]: """ Return fully-qualified names of all tmp files referenced by command. """ + if os.name == "nt": + return [f"/%TEMP%/{match.group(1)}" for match in re.finditer(re_tmp, command)] return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)] @@ -294,7 +303,9 @@ def nvcc_data_dependencies(commands: List[str]) -> Graph: fatbins[i].add(tmp) else: tmp_files[tmp] = i - if line.startswith("rm ") and not deps: + if (line.startswith("rm ") or line.startswith("erase ")) and not deps: + if os.name == "nt": + commands[i] = line.replace("/", "\\") deps.add(i - 1) graph.append(deps) return graph @@ -421,6 +432,8 @@ async def run_graph( """ Return outputs/errors (and optionally time/file info) from commands. """ + if os.name == "nt": + env.update(os.environ.copy()) tasks: List[Awaitable[Result]] = [] for i, (command, indices) in enumerate(zip(commands, graph)): deps = {tasks[j] for j in indices} diff --git a/tools/fast_nvcc/wrap_nvcc.bat.in b/tools/fast_nvcc/wrap_nvcc.bat.in new file mode 100644 index 000000000000..f02a751e3a4f --- /dev/null +++ b/tools/fast_nvcc/wrap_nvcc.bat.in @@ -0,0 +1 @@ +python "@FAST_NVCC_EXECUTABLE@" --nvcc "@CUDA_NVCC_EXECUTABLE_ORIGIN@" -- %*