Enable USE_FAST_NVCC on Windows (#95206)

USE_FAST_NVCC now works on Windows.

Fixes #67100

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95206
Approved by: https://github.com/ezyang
This commit is contained in:
mantaionut
2023-03-06 15:04:24 +00:00
committed by PyTorch MergeBot
parent 7a192cc51c
commit 3beafc91d1
4 changed files with 25 additions and 6 deletions

1
.gitignore vendored
View File

@ -64,6 +64,7 @@ third_party/build/
tools/coverage_plugins_package/pip-wheel-metadata/
tools/shared/_utils_internal.py
tools/fast_nvcc/wrap_nvcc.sh
tools/fast_nvcc/wrap_nvcc.bat
tools/fast_nvcc/tmp/
torch.egg-info/
torch/_C/__init__.pyi

View File

@ -768,13 +768,17 @@ endif()
# FAST_NVCC
# When USE_FAST_NVCC is enabled and an nvcc executable is known, generate a
# wrapper script from tools/fast_nvcc/wrap_nvcc.<ext>.in and point
# CUDA_NVCC_EXECUTABLE at it. The original nvcc path is remembered in
# CUDA_NVCC_EXECUTABLE_ORIGIN; the block is skipped whenever that variable is
# already set, so the saved path is never replaced by the wrapper's own path.
if(USE_FAST_NVCC AND CUDA_NVCC_EXECUTABLE AND NOT CUDA_NVCC_EXECUTABLE_ORIGIN)
  set(CUDA_NVCC_EXECUTABLE_ORIGIN "${CUDA_NVCC_EXECUTABLE}")

  # Wrapper flavour: a shell script by default, a batch file under MSVC.
  set(EXTENSION "sh")
  if(MSVC)
    set(EXTENSION "bat")
  endif()

  # Referenced by the wrapper template together with CUDA_NVCC_EXECUTABLE_ORIGIN.
  set(FAST_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/fast_nvcc.py")

  # Step 1: expand the template into tools/fast_nvcc/tmp/.
  configure_file(
    "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}.in"
    "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}"
  )

  # Step 2: copy the expanded wrapper into tools/fast_nvcc/ with the
  # read/execute permission bits listed below applied to the copy.
  file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}"
    DESTINATION "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/"
    FILE_PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE
  )

  # From here on, every use of CUDA_NVCC_EXECUTABLE goes through the wrapper.
  set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}")
endif()
mark_as_advanced(CUDA_NVCC_EXECUTABLE)

View File

@ -78,6 +78,8 @@ url_vars = f"{url_base}#keeping-intermediate-phase-files"
# regex for temporary file names
# Group 1 captures the bare "tmp..." file name; an optional leading directory
# ("/tmp/" here, "/Temp/" in the Windows variant) is matched but not captured.
# The negative lookbehind rejects matches that start in the middle of a longer
# word or path component.
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
if os.name == "nt":
    # Windows variant of the same pattern.
    # NOTE(review): the lookbehind class here contains "0" where the POSIX
    # pattern has "/" -- presumably intentional, confirm against real nvcc
    # --dryrun output on Windows.
    re_tmp = r"(?<![\w\-0])(?:\/Temp\/)?(tmp[^ \"\'\\]+)"
def fast_nvcc_warn(warning: str) -> None:
@ -141,7 +143,10 @@ def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
print(result.stdout, end="")
env = {}
commands = []
for line in result.stderr.splitlines():
output = result.stderr
if os.name == "nt":
output = result.stdout
for line in output.splitlines():
match = re.match(r"^#\$ (.*)$", line)
if match:
(stripped,) = match.groups()
@ -213,9 +218,11 @@ def unique_module_id_files(commands: List[str]) -> List[str]:
line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line)
if arr:
(filename,) = arr
if os.name == "nt":
filename = "%TEMP%\\" + filename
if not module_id:
module_id = module_id_contents(shlex.split(line))
uniqueified.append(f"echo -n '{module_id}' > '{filename}'")
uniqueified.append(f"echo -n '{module_id}' > \"{filename}\"")
uniqueified.append(line)
return uniqueified
@ -261,6 +268,8 @@ def files_mentioned(command: str) -> List[str]:
"""
Return fully-qualified names of all tmp files referenced by command.
"""
if os.name == "nt":
return [f"/%TEMP%/{match.group(1)}" for match in re.finditer(re_tmp, command)]
return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)]
@ -294,7 +303,9 @@ def nvcc_data_dependencies(commands: List[str]) -> Graph:
fatbins[i].add(tmp)
else:
tmp_files[tmp] = i
if line.startswith("rm ") and not deps:
if (line.startswith("rm ") or line.startswith("erase ")) and not deps:
if os.name == "nt":
commands[i] = line.replace("/", "\\")
deps.add(i - 1)
graph.append(deps)
return graph
@ -421,6 +432,8 @@ async def run_graph(
"""
Return outputs/errors (and optionally time/file info) from commands.
"""
if os.name == "nt":
env.update(os.environ.copy())
tasks: List[Awaitable[Result]] = []
for i, (command, indices) in enumerate(zip(commands, graph)):
deps = {tasks[j] for j in indices}

View File

@ -0,0 +1 @@
@rem Wrapper template, presumably wrap_nvcc.bat.in (file name not shown here; confirm).
@rem The FAST_NVCC_EXECUTABLE and CUDA_NVCC_EXECUTABLE_ORIGIN placeholders below are
@rem filled in at configure time; %* forwards every argument given to this wrapper.
python "@FAST_NVCC_EXECUTABLE@" --nvcc "@CUDA_NVCC_EXECUTABLE_ORIGIN@" -- %*