mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable ninja during python build process for MSVC (#3993)
This commit is contained in:
2
setup.py
2
setup.py
@ -306,6 +306,8 @@ class build_ext(build_ext_parent):
|
||||
else:
|
||||
# To generate .obj files for AutoGPU for the export class
|
||||
# a header file cannot build, so it has to be copied to someplace as a source file
|
||||
if not os.path.exists("torch/csrc/generated"):
|
||||
os.mkdir("torch/csrc/generated")
|
||||
if os.path.exists("torch/csrc/generated/AutoGPU_cpu_win.cpp"):
|
||||
os.remove("torch/csrc/generated/AutoGPU_cpu_win.cpp")
|
||||
shutil.copyfile("torch/csrc/cuda/AutoGPU.h", "torch/csrc/generated/AutoGPU_cpu_win.cpp")
|
||||
|
@ -1,3 +1,4 @@
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
import setuptools
|
||||
@ -61,10 +62,56 @@ class ninja_build_ext(setuptools.command.build_ext.build_ext):
|
||||
finally:
|
||||
setattr(obj, attr_name, orig_val)
|
||||
|
||||
orig_compile = distutils.unixccompiler.UnixCCompiler._compile
|
||||
orig_link = distutils.unixccompiler.UnixCCompiler.link
|
||||
if self.compiler.compiler_type == 'msvc':
|
||||
import distutils.msvccompiler
|
||||
import distutils.msvc9compiler
|
||||
if sys.version[0] == 2:
|
||||
orig_compiler = distutils.msvc9compiler.MSVCCompiler
|
||||
else:
|
||||
orig_compiler = distutils._msvccompiler.MSVCCompiler
|
||||
orig_compile = orig_compiler.compile
|
||||
orig_link = orig_compiler.link
|
||||
orig_spawn = orig_compiler.spawn
|
||||
else:
|
||||
orig_compiler = distutils.unixccompiler.UnixCCompiler
|
||||
orig_compile = orig_compiler._compile
|
||||
orig_link = orig_compiler.link
|
||||
|
||||
def _compile(self, obj, src, ext, cc_args, extra_postargs, pp_opts):
|
||||
def win_compile(self, sources,
|
||||
output_dir=None, macros=None, include_dirs=None, debug=0,
|
||||
extra_preargs=None, extra_postargs=None, depends=None):
|
||||
|
||||
def spawn(cmd):
|
||||
# Using regex to match src and obj
|
||||
|
||||
src_regex = re.compile('/T(p|c)(.*)')
|
||||
src_list = [m.group(2) for m in (
|
||||
src_regex.match(elem) for elem in cmd) if m]
|
||||
|
||||
obj_regex = re.compile('/Fo(.*)')
|
||||
obj_list = [m.group(1) for m in (
|
||||
obj_regex.match(elem) for elem in cmd) if m]
|
||||
|
||||
if len(src_list) >= 1 and len(obj_list) >= 1:
|
||||
src = src_list[0]
|
||||
obj = obj_list[0]
|
||||
else:
|
||||
# Cannot find src or obj, revert back to original style
|
||||
return orig_spawn(cmd)
|
||||
|
||||
builder.writer.build(
|
||||
[obj], 'compile', [src],
|
||||
variables={
|
||||
'cmd': cmd,
|
||||
'deps': 'msvc'
|
||||
})
|
||||
|
||||
with patch(self, 'spawn', spawn):
|
||||
orig_compile(self, sources,
|
||||
output_dir, macros, include_dirs, debug,
|
||||
extra_preargs, extra_postargs, depends)
|
||||
|
||||
def unix_compile(self, obj, src, ext, cc_args, extra_postargs, pp_opts):
|
||||
depfile = os.path.splitext(obj)[0] + '.d'
|
||||
|
||||
def spawn(cmd):
|
||||
@ -93,7 +140,14 @@ class ninja_build_ext(setuptools.command.build_ext.build_ext):
|
||||
export_symbols, debug, extra_preargs,
|
||||
extra_postargs, build_temp, target_lang)
|
||||
|
||||
with patch(distutils.unixccompiler.UnixCCompiler, '_compile', _compile):
|
||||
with patch(distutils.unixccompiler.UnixCCompiler, 'link', link):
|
||||
if self.compiler.compiler_type == 'msvc':
|
||||
_compile_func = win_compile
|
||||
_compile_func_name = 'compile'
|
||||
else:
|
||||
_compile_func = unix_compile
|
||||
_compile_func_name = '_compile'
|
||||
|
||||
with patch(orig_compiler, _compile_func_name, _compile_func):
|
||||
with patch(orig_compiler, 'link', link):
|
||||
with patch(self, 'force', True):
|
||||
self._build_default(ext)
|
||||
|
@ -83,4 +83,5 @@ def split_types(file_name, ninja_global):
|
||||
|
||||
# when called from ninja
|
||||
if __name__ == '__main__':
|
||||
split_types(sys.argv[1], None)
|
||||
file_name = sys.argv[1].strip("'")
|
||||
split_types(file_name, None)
|
||||
|
Reference in New Issue
Block a user