mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
USE_FAST_NVCC Windows (#95206)
USE_FAST_NVCC now works on Windows. Fixes #67100 Pull Request resolved: https://github.com/pytorch/pytorch/pull/95206 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
7a192cc51c
commit
3beafc91d1
1
.gitignore
vendored
1
.gitignore
vendored
@ -64,6 +64,7 @@ third_party/build/
|
|||||||
tools/coverage_plugins_package/pip-wheel-metadata/
|
tools/coverage_plugins_package/pip-wheel-metadata/
|
||||||
tools/shared/_utils_internal.py
|
tools/shared/_utils_internal.py
|
||||||
tools/fast_nvcc/wrap_nvcc.sh
|
tools/fast_nvcc/wrap_nvcc.sh
|
||||||
|
tools/fast_nvcc/wrap_nvcc.bat
|
||||||
tools/fast_nvcc/tmp/
|
tools/fast_nvcc/tmp/
|
||||||
torch.egg-info/
|
torch.egg-info/
|
||||||
torch/_C/__init__.pyi
|
torch/_C/__init__.pyi
|
||||||
|
@ -768,13 +768,17 @@ endif()
|
|||||||
# FAST_NVCC
|
# FAST_NVCC
|
||||||
if(USE_FAST_NVCC AND CUDA_NVCC_EXECUTABLE AND NOT CUDA_NVCC_EXECUTABLE_ORIGIN)
|
if(USE_FAST_NVCC AND CUDA_NVCC_EXECUTABLE AND NOT CUDA_NVCC_EXECUTABLE_ORIGIN)
|
||||||
set(CUDA_NVCC_EXECUTABLE_ORIGIN "${CUDA_NVCC_EXECUTABLE}")
|
set(CUDA_NVCC_EXECUTABLE_ORIGIN "${CUDA_NVCC_EXECUTABLE}")
|
||||||
|
set(EXTENSION "sh")
|
||||||
|
if (MSVC)
|
||||||
|
set(EXTENSION "bat")
|
||||||
|
endif()
|
||||||
set(FAST_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/fast_nvcc.py")
|
set(FAST_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/fast_nvcc.py")
|
||||||
configure_file(${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh.in "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh")
|
configure_file(${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}.in "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}")
|
||||||
file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh"
|
file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.${EXTENSION}"
|
||||||
DESTINATION "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/"
|
DESTINATION "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/"
|
||||||
FILE_PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE
|
FILE_PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE
|
||||||
)
|
)
|
||||||
set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh")
|
set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.${EXTENSION}")
|
||||||
endif()
|
endif()
|
||||||
mark_as_advanced(CUDA_NVCC_EXECUTABLE)
|
mark_as_advanced(CUDA_NVCC_EXECUTABLE)
|
||||||
|
|
||||||
|
@ -78,6 +78,8 @@ url_vars = f"{url_base}#keeping-intermediate-phase-files"
|
|||||||
|
|
||||||
# regex for temporary file names
|
# regex for temporary file names
|
||||||
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
|
re_tmp = r"(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)"
|
||||||
|
if os.name == "nt":
|
||||||
|
re_tmp = r"(?<![\w\-0])(?:\/Temp\/)?(tmp[^ \"\'\\]+)"
|
||||||
|
|
||||||
|
|
||||||
def fast_nvcc_warn(warning: str) -> None:
|
def fast_nvcc_warn(warning: str) -> None:
|
||||||
@ -141,7 +143,10 @@ def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
|
|||||||
print(result.stdout, end="")
|
print(result.stdout, end="")
|
||||||
env = {}
|
env = {}
|
||||||
commands = []
|
commands = []
|
||||||
for line in result.stderr.splitlines():
|
output = result.stderr
|
||||||
|
if os.name == "nt":
|
||||||
|
output = result.stdout
|
||||||
|
for line in output.splitlines():
|
||||||
match = re.match(r"^#\$ (.*)$", line)
|
match = re.match(r"^#\$ (.*)$", line)
|
||||||
if match:
|
if match:
|
||||||
(stripped,) = match.groups()
|
(stripped,) = match.groups()
|
||||||
@ -213,9 +218,11 @@ def unique_module_id_files(commands: List[str]) -> List[str]:
|
|||||||
line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line)
|
line = re.sub(r"\s*\-\-gen\_module\_id\_file\s*", " ", line)
|
||||||
if arr:
|
if arr:
|
||||||
(filename,) = arr
|
(filename,) = arr
|
||||||
|
if os.name == "nt":
|
||||||
|
filename = "%TEMP%\\" + filename
|
||||||
if not module_id:
|
if not module_id:
|
||||||
module_id = module_id_contents(shlex.split(line))
|
module_id = module_id_contents(shlex.split(line))
|
||||||
uniqueified.append(f"echo -n '{module_id}' > '{filename}'")
|
uniqueified.append(f"echo -n '{module_id}' > \"{filename}\"")
|
||||||
uniqueified.append(line)
|
uniqueified.append(line)
|
||||||
return uniqueified
|
return uniqueified
|
||||||
|
|
||||||
@ -261,6 +268,8 @@ def files_mentioned(command: str) -> List[str]:
|
|||||||
"""
|
"""
|
||||||
Return fully-qualified names of all tmp files referenced by command.
|
Return fully-qualified names of all tmp files referenced by command.
|
||||||
"""
|
"""
|
||||||
|
if os.name == "nt":
|
||||||
|
return [f"/%TEMP%/{match.group(1)}" for match in re.finditer(re_tmp, command)]
|
||||||
return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)]
|
return [f"/tmp/{match.group(1)}" for match in re.finditer(re_tmp, command)]
|
||||||
|
|
||||||
|
|
||||||
@ -294,7 +303,9 @@ def nvcc_data_dependencies(commands: List[str]) -> Graph:
|
|||||||
fatbins[i].add(tmp)
|
fatbins[i].add(tmp)
|
||||||
else:
|
else:
|
||||||
tmp_files[tmp] = i
|
tmp_files[tmp] = i
|
||||||
if line.startswith("rm ") and not deps:
|
if (line.startswith("rm ") or line.startswith("erase ")) and not deps:
|
||||||
|
if os.name == "nt":
|
||||||
|
commands[i] = line.replace("/", "\\")
|
||||||
deps.add(i - 1)
|
deps.add(i - 1)
|
||||||
graph.append(deps)
|
graph.append(deps)
|
||||||
return graph
|
return graph
|
||||||
@ -421,6 +432,8 @@ async def run_graph(
|
|||||||
"""
|
"""
|
||||||
Return outputs/errors (and optionally time/file info) from commands.
|
Return outputs/errors (and optionally time/file info) from commands.
|
||||||
"""
|
"""
|
||||||
|
if os.name == "nt":
|
||||||
|
env.update(os.environ.copy())
|
||||||
tasks: List[Awaitable[Result]] = []
|
tasks: List[Awaitable[Result]] = []
|
||||||
for i, (command, indices) in enumerate(zip(commands, graph)):
|
for i, (command, indices) in enumerate(zip(commands, graph)):
|
||||||
deps = {tasks[j] for j in indices}
|
deps = {tasks[j] for j in indices}
|
||||||
|
1
tools/fast_nvcc/wrap_nvcc.bat.in
Normal file
1
tools/fast_nvcc/wrap_nvcc.bat.in
Normal file
@ -0,0 +1 @@
|
|||||||
|
python "@FAST_NVCC_EXECUTABLE@" --nvcc "@CUDA_NVCC_EXECUTABLE_ORIGIN@" -- %*
|
Reference in New Issue
Block a user