mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
USE_FAST_NVCC Windows (#95206)
USE_FAST_NVCC now works on Windows. Fixes #67100 Pull Request resolved: https://github.com/pytorch/pytorch/pull/95206 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
7a192cc51c
commit
3beafc91d1
1
.gitignore
vendored
1
.gitignore
vendored
@ -64,6 +64,7 @@ third_party/build/
|
||||
tools/coverage_plugins_package/pip-wheel-metadata/
|
||||
tools/shared/_utils_internal.py
|
||||
tools/fast_nvcc/wrap_nvcc.sh
|
||||
tools/fast_nvcc/wrap_nvcc.bat
|
||||
tools/fast_nvcc/tmp/
|
||||
torch.egg-info/
|
||||
torch/_C/__init__.pyi
|
||||
|
@ -768,13 +768,17 @@ endif()
|
||||
# FAST_NVCC
|
||||
if(USE_FAST_NVCC AND CUDA_NVCC_EXECUTABLE AND NOT CUDA_NVCC_EXECUTABLE_ORIGIN)
|
||||
set(CUDA_NVCC_EXECUTABLE_ORIGIN "${CUDA_NVCC_EXECUTABLE}")
|
||||
set(EXTENSION "sh")
|
||||
if (MSVC)
|
||||
set(EXTENSION "bat")
|
||||
endif()
|
||||
set(FAST_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/fast_nvcc.py")
|
||||
configure_file(${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh.in "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh")
|
||||
file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh"
|
||||
configure_file(${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}.in "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}")
|
||||
file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}"
|
||||
DESTINATION "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/"
|
||||
FILE_PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE
|
||||
)
|
||||
set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh")
|
||||
set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}")
|
||||
endif()
|
||||
mark_as_advanced(CUDA_NVCC_EXECUTABLE)
|
||||
|
||||
|
@ -78,6 +78,8 @@ url_vars = f"{url_base}#keeping-intermediate-phase-files"
|
||||
|
||||
# regex for temporary file names
|
||||
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
|
||||
if os.name == "nt":
|
||||
re_tmp = r"(?<![\w\-0])(?:\/Temp\/)?(tmp[^ \"\'\\]+)"
|
||||
|
||||
|
||||
def fast_nvcc_warn(warning: str) -> None:
|
||||
@ -141,7 +143,10 @@ def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
|
||||
print(result.stdout, end="")
|
||||
env = {}
|
||||
commands = []
|
||||
for line in result.stderr.splitlines():
|
||||
output = result.stderr
|
||||
if os.name == "nt":
|
||||
output = result.stdout
|
||||
for line in output.splitlines():
|
||||
match = re.match(r"^#\$ (.*)$", line)
|
||||
if match:
|
||||
(stripped,) = match.groups()
|
||||
@ -213,9 +218,11 @@ def unique_module_id_files(commands: List[str]) -> List[str]:
|
||||
line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line)
|
||||
if arr:
|
||||
(filename,) = arr
|
||||
if os.name == "nt":
|
||||
filename = "%TEMP%\\" + filename
|
||||
if not module_id:
|
||||
module_id = module_id_contents(shlex.split(line))
|
||||
uniqueified.append(f"echo -n '{module_id}' > '{filename}'")
|
||||
uniqueified.append(f"echo -n '{module_id}' > \"{filename}\"")
|
||||
uniqueified.append(line)
|
||||
return uniqueified
|
||||
|
||||
@ -261,6 +268,8 @@ def files_mentioned(command: str) -> List[str]:
|
||||
"""
|
||||
Return fully-qualified names of all tmp files referenced by command.
|
||||
"""
|
||||
if os.name == "nt":
|
||||
return [f"/%TEMP%/{match.group(1)}" for match in re.finditer(re_tmp, command)]
|
||||
return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)]
|
||||
|
||||
|
||||
@ -294,7 +303,9 @@ def nvcc_data_dependencies(commands: List[str]) -> Graph:
|
||||
fatbins[i].add(tmp)
|
||||
else:
|
||||
tmp_files[tmp] = i
|
||||
if line.startswith("rm ") and not deps:
|
||||
if (line.startswith("rm ") or line.startswith("erase ")) and not deps:
|
||||
if os.name == "nt":
|
||||
commands[i] = line.replace("/", "\\")
|
||||
deps.add(i - 1)
|
||||
graph.append(deps)
|
||||
return graph
|
||||
@ -421,6 +432,8 @@ async def run_graph(
|
||||
"""
|
||||
Return outputs/errors (and optionally time/file info) from commands.
|
||||
"""
|
||||
if os.name == "nt":
|
||||
env.update(os.environ.copy())
|
||||
tasks: List[Awaitable[Result]] = []
|
||||
for i, (command, indices) in enumerate(zip(commands, graph)):
|
||||
deps = {tasks[j] for j in indices}
|
||||
|
1
tools/fast_nvcc/wrap_nvcc.bat.in
Normal file
1
tools/fast_nvcc/wrap_nvcc.bat.in
Normal file
@ -0,0 +1 @@
|
||||
python "@FAST_NVCC_EXECUTABLE@" --nvcc "@CUDA_NVCC_EXECUTABLE_ORIGIN@" -- %*
|
Reference in New Issue
Block a user