Add torch.compile support to minifier (#90308)

Initial fix for https://github.com/pytorch/torchdynamo/issues/1964.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90308
Approved by: https://github.com/mlazos
This commit is contained in:
William Wen
2022-12-14 18:24:42 +00:00
committed by PyTorch MergeBot
parent fde5646f3d
commit e9dc8cc19b
3 changed files with 69 additions and 12 deletions

View File

@ -204,6 +204,54 @@ torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"
def test_after_aot_cuda_accuracy_backend_passes(self):
self._test_after_aot_backend_passes("cuda", 4, TRITON_ACCURACY_ERROR)
def _test_torch_compile(self, repro_after, repro_level, backend_code):
run_code = textwrap.dedent(
"""\
def inner(x):
for _ in range(3):
x = torch.sin(x)
x = torch.relu(x)
for _ in range(3):
x = torch.cos(x)
return x
inner_opt = torch.compile(inner)
inner_opt(torch.randn(20, 20))
"""
)
patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, "cpu")
self.assertIsNotNone(patch_code)
(test_proc, _, repro_proc), _ = self._run_full_test(
run_code, repro_after, repro_level, patch_code
)
return (
(test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")),
(test_proc.returncode, repro_proc.returncode),
)
def test_torch_compile_after_dynamo_compile_error(self):
(tb1, tb2), _ = self._test_torch_compile("dynamo", 2, CPP_COMPILE_ERROR)
self.assertIn("CppCompileError", tb1)
self.assertIn("CppCompileError", tb2)
def test_torch_compile_after_dynamo_accuracy_error(self):
(tb1, tb2), _ = self._test_torch_compile("dynamo", 4, CPP_ACCURACY_ERROR)
self.assertIn("AccuracyError", tb1)
self.assertIn("AccuracyError", tb2)
def test_torch_compile_after_aot_compile_error(self):
(tb1, tb2), _ = self._test_torch_compile("aot", 2, CPP_COMPILE_ERROR)
self.assertIn("CppCompileError", tb1)
self.assertIn("CppCompileError", tb2)
def test_torch_compile_after_aot_accuracy_error(self):
(tb1, tb2), _ = self._test_torch_compile("aot", 4, CPP_ACCURACY_ERROR)
self.assertIn("AccuracyError", tb1)
self.assertIn("AccuracyError", tb2)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1139,6 +1139,20 @@ from ._linalg_utils import ( # type: ignore[misc]
lstsq,
)
class _TorchCompileInductorWrapper:
def __init__(self, mode, passes):
from torch._dynamo.eval_frame import lookup_backend
from torch._inductor.config import InductorConfigContext
self.compile_fn = lookup_backend("inductor")
self.cm = InductorConfigContext(mode if mode is not None else passes)
self._torchdynamo_orig_callable = self.compile_fn
def __call__(self, model_, inputs_):
with self.cm:
return self.compile_fn(model_, inputs_)
def compile(model: Optional[Callable] = None, *,
fullgraph: builtins.bool = False,
dynamic: builtins.bool = False,
@ -1189,22 +1203,12 @@ def compile(model: Optional[Callable] = None, *,
return fn
import torch._dynamo
from torch._dynamo.eval_frame import lookup_backend
from torch._inductor.config import InductorConfigContext
if mode is not None and passes is not None:
raise RuntimeError("Either mode or passes can be specified, but both can't be specified at the same time.")
if mode is None and passes is None:
mode = "default"
if backend == "inductor":
compile_fn = lookup_backend(backend)
cm = InductorConfigContext(mode if mode is not None else passes)
def _compile_fn(model_, inputs_):
with cm:
return compile_fn(model_, inputs_)
_compile_fn._torchdynamo_orig_callable = compile_fn # type: ignore[attr-defined]
backend = _compile_fn
backend = _TorchCompileInductorWrapper(mode, passes)
return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, **kwargs)(model)

View File

@ -349,7 +349,12 @@ def _optimize_catch_errors(
def get_compiler_fn(compiler_fn):
from .debug_utils import wrap_backend_debug
compiler_str = compiler_fn if isinstance(compiler_fn, str) else None
if isinstance(compiler_fn, torch._TorchCompileInductorWrapper):
compiler_str = "inductor"
elif isinstance(compiler_fn, str):
compiler_str = compiler_fn
else:
compiler_str = None
compiler_fn = lookup_backend(compiler_fn)
return wrap_backend_debug(compiler_fn, compiler_str)