Enable USE_FAST_NVCC on Windows (#95206)

USE_FAST_NVCC now works on Windows.

Fixes #67100

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95206
Approved by: https://github.com/ezyang
This commit is contained in:
mantaionut
2023-03-06 15:04:24 +00:00
committed by PyTorch MergeBot
parent 7a192cc51c
commit 3beafc91d1
4 changed files with 25 additions and 6 deletions

1
.gitignore vendored
View File

@ -64,6 +64,7 @@ third_party/build/
tools/coverage_plugins_package/pip-wheel-metadata/ tools/coverage_plugins_package/pip-wheel-metadata/
tools/shared/_utils_internal.py tools/shared/_utils_internal.py
tools/fast_nvcc/wrap_nvcc.sh tools/fast_nvcc/wrap_nvcc.sh
tools/fast_nvcc/wrap_nvcc.bat
tools/fast_nvcc/tmp/ tools/fast_nvcc/tmp/
torch.egg-info/ torch.egg-info/
torch/_C/__init__.pyi torch/_C/__init__.pyi

View File

@ -768,13 +768,17 @@ endif()
# FAST_NVCC
# When USE_FAST_NVCC is enabled, point CUDA_NVCC_EXECUTABLE at a generated
# wrapper script that forwards to tools/fast_nvcc/fast_nvcc.py. The original
# nvcc path is saved in CUDA_NVCC_EXECUTABLE_ORIGIN; the block is skipped when
# that variable is already set.
if(USE_FAST_NVCC AND CUDA_NVCC_EXECUTABLE AND NOT CUDA_NVCC_EXECUTABLE_ORIGIN)
  set(CUDA_NVCC_EXECUTABLE_ORIGIN "${CUDA_NVCC_EXECUTABLE}")

  # The wrapper is launched by the host's command interpreter, so pick its
  # flavour from the host OS (fast_nvcc.py likewise branches on os.name)
  # rather than from the target compiler.
  if(CMAKE_HOST_WIN32)
    set(EXTENSION "bat")
  else()
    set(EXTENSION "sh")
  endif()

  # Substituted into the wrapper template as @FAST_NVCC_EXECUTABLE@.
  set(FAST_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/fast_nvcc.py")

  # Render the template into tmp/ first ...
  configure_file(
    "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}.in"
    "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}"
  )
  # ... then copy it next to fast_nvcc.py with execute permissions set.
  file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}"
    DESTINATION "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/"
    FILE_PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE
  )

  # From here on every nvcc invocation goes through the wrapper.
  set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}")
endif()
mark_as_advanced(CUDA_NVCC_EXECUTABLE)

View File

@ -78,6 +78,8 @@ url_vars = f"{url_base}#keeping-intermediate-phase-files"
# regex for temporary file names # regex for temporary file names
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)" re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
if os.name == "nt":
re_tmp = r"(?<![\w\-0])(?:\/Temp\/)?(tmp[^ \"\'\\]+)"
def fast_nvcc_warn(warning: str) -> None: def fast_nvcc_warn(warning: str) -> None:
@ -141,7 +143,10 @@ def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
print(result.stdout, end="") print(result.stdout, end="")
env = {} env = {}
commands = [] commands = []
for line in result.stderr.splitlines(): output = result.stderr
if os.name == "nt":
output = result.stdout
for line in output.splitlines():
match = re.match(r"^#\$ (.*)$", line) match = re.match(r"^#\$ (.*)$", line)
if match: if match:
(stripped,) = match.groups() (stripped,) = match.groups()
@ -213,9 +218,11 @@ def unique_module_id_files(commands: List[str]) -> List[str]:
line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line) line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line)
if arr: if arr:
(filename,) = arr (filename,) = arr
if os.name == "nt":
filename = "%TEMP%\\" + filename
if not module_id: if not module_id:
module_id = module_id_contents(shlex.split(line)) module_id = module_id_contents(shlex.split(line))
uniqueified.append(f"echo -n '{module_id}' > '{filename}'") uniqueified.append(f"echo -n '{module_id}' > \"{filename}\"")
uniqueified.append(line) uniqueified.append(line)
return uniqueified return uniqueified
@ -261,6 +268,8 @@ def files_mentioned(command: str) -> List[str]:
""" """
Return fully-qualified names of all tmp files referenced by command. Return fully-qualified names of all tmp files referenced by command.
""" """
if os.name == "nt":
return [f"/%TEMP%/{match.group(1)}" for match in re.finditer(re_tmp, command)]
return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)] return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)]
@ -294,7 +303,9 @@ def nvcc_data_dependencies(commands: List[str]) -> Graph:
fatbins[i].add(tmp) fatbins[i].add(tmp)
else: else:
tmp_files[tmp] = i tmp_files[tmp] = i
if line.startswith("rm ") and not deps: if (line.startswith("rm ") or line.startswith("erase ")) and not deps:
if os.name == "nt":
commands[i] = line.replace("/", "\\")
deps.add(i - 1) deps.add(i - 1)
graph.append(deps) graph.append(deps)
return graph return graph
@ -421,6 +432,8 @@ async def run_graph(
""" """
Return outputs/errors (and optionally time/file info) from commands. Return outputs/errors (and optionally time/file info) from commands.
""" """
if os.name == "nt":
env.update(os.environ.copy())
tasks: List[Awaitable[Result]] = [] tasks: List[Awaitable[Result]] = []
for i, (command, indices) in enumerate(zip(commands, graph)): for i, (command, indices) in enumerate(zip(commands, graph)):
deps = {tasks[j] for j in indices} deps = {tasks[j] for j in indices}

View File

@ -0,0 +1 @@
@echo off
rem Windows wrapper for nvcc: forwards every argument (%*) to fast_nvcc.py,
rem passing the real nvcc path via --nvcc. The @...@ placeholders are filled
rem in by CMake's configure_file. Command echo is turned off so that only the
rem wrapped tool's own output reaches stdout.
python "@FAST_NVCC_EXECUTABLE@" --nvcc "@CUDA_NVCC_EXECUTABLE_ORIGIN@" -- %*