mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add torch.compile support to minifier (#90308)
Initial fix for https://github.com/pytorch/torchdynamo/issues/1964. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90308 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
fde5646f3d
commit
e9dc8cc19b
@ -204,6 +204,54 @@ torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"
|
||||
def test_after_aot_cuda_accuracy_backend_passes(self):
    # End-to-end minifier run on CUDA with an injected accuracy failure.
    # Level 4 is paired with the accuracy-error fault codes elsewhere in
    # this file (vs. level 2 for compile errors); TRITON_ACCURACY_ERROR is
    # the fault snippet patched into the generated Triton kernel.
    self._test_after_aot_backend_passes("cuda", 4, TRITON_ACCURACY_ERROR)
def _test_torch_compile(self, repro_after, repro_level, backend_code):
    """Drive the full minifier pipeline on a script that uses torch.compile().

    Runs a small sin/relu/cos model compiled via ``torch.compile`` (rather
    than ``torch._dynamo.optimize``) with ``backend_code`` patched in as an
    injected fault, then returns ``((test_stderr, repro_stderr),
    (test_returncode, repro_returncode))`` for the caller to assert on.
    """
    source = textwrap.dedent(
        """\
        def inner(x):
            for _ in range(3):
                x = torch.sin(x)
            x = torch.relu(x)
            for _ in range(3):
                x = torch.cos(x)
            return x

        inner_opt = torch.compile(inner)

        inner_opt(torch.randn(20, 20))
        """
    )

    # Patch the injected fault into the codegen for "relu" on cpu; a None
    # result means the helper could not locate the patch point.
    patch = self._gen_codegen_fn_patch_code("relu", backend_code, "cpu")
    self.assertIsNotNone(patch)

    (test_proc, _, repro_proc), _ = self._run_full_test(
        source, repro_after, repro_level, patch
    )

    traces = (
        test_proc.stderr.decode("utf-8"),
        repro_proc.stderr.decode("utf-8"),
    )
    codes = (test_proc.returncode, repro_proc.returncode)
    return traces, codes
def test_torch_compile_after_dynamo_compile_error(self):
    # Inject a C++ compile error, minify after dynamo (level 2 = compile
    # errors); both the original run and the generated repro must surface it.
    traces, _ = self._test_torch_compile("dynamo", 2, CPP_COMPILE_ERROR)
    for trace in traces:
        self.assertIn("CppCompileError", trace)
def test_torch_compile_after_dynamo_accuracy_error(self):
    # Inject an accuracy fault, minify after dynamo (level 4 = accuracy);
    # both the original run and the generated repro must report AccuracyError.
    traces, _ = self._test_torch_compile("dynamo", 4, CPP_ACCURACY_ERROR)
    for trace in traces:
        self.assertIn("AccuracyError", trace)
def test_torch_compile_after_aot_compile_error(self):
    # Same as the dynamo variant, but minifying after AOT autograd; the
    # injected C++ compile error must appear in test and repro output alike.
    traces, _ = self._test_torch_compile("aot", 2, CPP_COMPILE_ERROR)
    for trace in traces:
        self.assertIn("CppCompileError", trace)
def test_torch_compile_after_aot_accuracy_error(self):
    # Accuracy-fault variant minified after AOT autograd; AccuracyError must
    # show up in both the failing run and the minified repro.
    traces, _ = self._test_torch_compile("aot", 4, CPP_ACCURACY_ERROR)
    for trace in traces:
        self.assertIn("AccuracyError", trace)
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -1139,6 +1139,20 @@ from ._linalg_utils import ( # type: ignore[misc]
|
||||
lstsq,
|
||||
)
|
||||
|
||||
class _TorchCompileInductorWrapper:
    """Callable backend wrapper used by ``torch.compile`` for inductor.

    Wraps the raw inductor compile function so each compilation runs inside
    an ``InductorConfigContext`` built from either ``mode`` or ``passes``.
    Using a named class (instead of a closure) gives downstream tooling a
    stable type to recognize, and ``_torchdynamo_orig_callable`` lets it
    unwrap back to the undecorated backend.
    """

    def __init__(self, mode, passes):
        # Imported lazily here rather than at module top level — presumably
        # to avoid import cycles during torch startup; confirm before moving.
        from torch._dynamo.eval_frame import lookup_backend
        from torch._inductor.config import InductorConfigContext

        # The caller is expected to pass at most one of mode/passes
        # (torch.compile raises if both are given); prefer mode when set.
        config_arg = passes if mode is None else mode
        self.compile_fn = lookup_backend("inductor")
        self.cm = InductorConfigContext(config_arg)
        # Exposes the raw backend for dynamo's debug/unwrapping machinery.
        self._torchdynamo_orig_callable = self.compile_fn

    def __call__(self, model_, inputs_):
        # Apply the inductor config only for the duration of this call.
        with self.cm:
            return self.compile_fn(model_, inputs_)
||||
def compile(model: Optional[Callable] = None, *,
|
||||
fullgraph: builtins.bool = False,
|
||||
dynamic: builtins.bool = False,
|
||||
@ -1189,22 +1203,12 @@ def compile(model: Optional[Callable] = None, *,
|
||||
return fn
|
||||
|
||||
import torch._dynamo
|
||||
from torch._dynamo.eval_frame import lookup_backend
|
||||
from torch._inductor.config import InductorConfigContext
|
||||
if mode is not None and passes is not None:
|
||||
raise RuntimeError("Either mode or passes can be specified, but both can't be specified at the same time.")
|
||||
if mode is None and passes is None:
|
||||
mode = "default"
|
||||
if backend == "inductor":
|
||||
compile_fn = lookup_backend(backend)
|
||||
cm = InductorConfigContext(mode if mode is not None else passes)
|
||||
|
||||
def _compile_fn(model_, inputs_):
|
||||
with cm:
|
||||
return compile_fn(model_, inputs_)
|
||||
|
||||
_compile_fn._torchdynamo_orig_callable = compile_fn # type: ignore[attr-defined]
|
||||
backend = _compile_fn
|
||||
backend = _TorchCompileInductorWrapper(mode, passes)
|
||||
return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, **kwargs)(model)
|
||||
|
||||
|
||||
|
@ -349,7 +349,12 @@ def _optimize_catch_errors(
|
||||
def get_compiler_fn(compiler_fn):
    """Resolve a user-supplied backend to a debug-wrapped compiler callable.

    ``compiler_fn`` may be a backend name (str), an arbitrary backend
    callable, or the ``_TorchCompileInductorWrapper`` produced by
    ``torch.compile``. The string identity is threaded through to
    ``wrap_backend_debug`` so the minifier can record which backend failed:
    "inductor" for the torch.compile wrapper, the name itself when a string
    was passed, and None for opaque callables.
    """
    from .debug_utils import wrap_backend_debug

    # Fix: the previous one-line conditional assignment of compiler_str was
    # left in place above this chain and was immediately shadowed — dead
    # code removed; the if/elif/else below is the single source of truth.
    if isinstance(compiler_fn, torch._TorchCompileInductorWrapper):
        compiler_str = "inductor"
    elif isinstance(compiler_fn, str):
        compiler_str = compiler_fn
    else:
        compiler_str = None
    compiler_fn = lookup_backend(compiler_fn)
    return wrap_backend_debug(compiler_fn, compiler_str)
|
||||
|
||||
|
Reference in New Issue
Block a user