# Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
# Fixes https://github.com/pytorch/torchdynamo/issues/1965
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/90501
# Approved by: https://github.com/mlazos
# Owner(s): ["module: inductor"]
|
|
import functools
|
|
import textwrap
|
|
import unittest
|
|
|
|
import torch
|
|
import torch._dynamo
|
|
import torch._inductor.utils
|
|
from torch._dynamo.test_minifier_common import MinifierTestBase
|
|
from torch.testing._internal.common_utils import IS_MACOS
|
|
|
|
_HAS_TRITON = torch._inductor.utils.has_triton()
|
|
requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cuda")
|
|
|
|
CPP_COMPILE_ERROR = """\
|
|
def cpp_compile_error(x):
|
|
return "compile error!"
|
|
"""
|
|
|
|
CPP_RUNTIME_ERROR = """\
|
|
def cpp_runtime_error(x):
|
|
return f"{x}; throw 1"
|
|
"""
|
|
|
|
CPP_ACCURACY_ERROR = """\
|
|
def cpp_accuracy_error(x):
|
|
return f"{x} + decltype({x})(1)"
|
|
"""
|
|
|
|
TRITON_COMPILE_ERROR = """\
|
|
def triton_compile_error(x):
|
|
return "compile error!"
|
|
"""
|
|
|
|
# NOTE: there is currently not an easy way to cause a triton runtime error.
|
|
TRITON_RUNTIME_ERROR = """\
|
|
def triton_runtime_error(x):
|
|
return f"{x}; assert?"
|
|
"""
|
|
|
|
TRITON_ACCURACY_ERROR = """\
|
|
def triton_accuracy_error(x):
|
|
return f"{x} + 1"
|
|
"""
|
|
|
|
|
|
class MinifierTests(MinifierTestBase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
super().tearDownClass()
|
|
|
|
# Generates code that patches CppOverrides/TritonOverrides.
|
|
def _gen_codegen_fn_patch_code(self, old_fn_name, new_fn_code, device):
|
|
new_fn_name = self._get_fn_name(new_fn_code)
|
|
if new_fn_name is not None:
|
|
patch_code = f"""\
|
|
import torch._inductor.codegen.{"cpp" if device == "cpu" else "triton"} as codegen
|
|
overrides = codegen.{"CppOverrides" if device == "cpu" else "TritonOverrides"}
|
|
vec_overrides = codegen.{"CppVecOverrides" if device == "cpu" else "TritonOverrides"}
|
|
{new_fn_code}
|
|
overrides.{old_fn_name} = staticmethod({new_fn_name})
|
|
vec_overrides.{old_fn_name} = staticmethod({new_fn_name})
|
|
"""
|
|
return f"""\
|
|
{patch_code}
|
|
isolate_fails_code_str = \"\"\"\\
|
|
{patch_code}
|
|
torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"
|
|
\"\"\"
|
|
"""
|
|
|
|
# Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA)
|
|
def _test_after_aot(self, device, backend_code, repro_level):
|
|
run_code = textwrap.dedent(
|
|
f"""\
|
|
@torch._dynamo.optimize("inductor")
|
|
def inner(x):
|
|
for _ in range(3):
|
|
x = torch.sin(x)
|
|
x = torch.relu(x)
|
|
for _ in range(3):
|
|
x = torch.cos(x)
|
|
return x
|
|
|
|
inner(torch.randn(20, 20).to("{device}"))
|
|
"""
|
|
)
|
|
patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, device)
|
|
self.assertIsNotNone(patch_code)
|
|
(test_proc, _, repro_proc), _ = self._run_full_test(
|
|
run_code, "aot", repro_level, patch_code
|
|
)
|
|
return (
|
|
(test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")),
|
|
(test_proc.returncode, repro_proc.returncode),
|
|
)
|
|
|
|
def test_after_aot_cpu_compile_error(self):
|
|
(tb1, tb2), _ = self._test_after_aot("cpu", CPP_COMPILE_ERROR, 2)
|
|
self.assertIn("CppCompileError", tb1)
|
|
self.assertIn("CppCompileError", tb2)
|
|
|
|
def test_after_aot_cpu_accuracy_error(self):
|
|
(tb1, tb2), _ = self._test_after_aot("cpu", CPP_ACCURACY_ERROR, 4)
|
|
self.assertIn("AccuracyError", tb1)
|
|
self.assertIn("AccuracyError", tb2)
|
|
|
|
@requires_cuda()
|
|
def test_after_aot_cuda_compile_error(self):
|
|
(tb1, tb2), _ = self._test_after_aot("cuda", TRITON_COMPILE_ERROR, 2)
|
|
self.assertIn("SyntaxError", tb1)
|
|
self.assertIn("SyntaxError", tb2)
|
|
|
|
@requires_cuda()
|
|
def test_after_aot_cuda_accuracy_error(self):
|
|
(tb1, tb2), _ = self._test_after_aot("cuda", TRITON_ACCURACY_ERROR, 4)
|
|
self.assertIn("AccuracyError", tb1)
|
|
self.assertIn("AccuracyError", tb2)
|
|
|
|
# Test that runtime errors after aot can be repro'd (CPU only for now)
|
|
def _test_after_aot_runtime_error(self, device, backend_code):
|
|
run_code = textwrap.dedent(
|
|
f"""\
|
|
@torch._dynamo.optimize("inductor")
|
|
def inner(x):
|
|
for _ in range(3):
|
|
x = torch.sin(x)
|
|
x = torch.relu(x)
|
|
for _ in range(3):
|
|
x = torch.cos(x)
|
|
return x
|
|
|
|
inner(torch.randn(20, 20).to("{device}"))
|
|
"""
|
|
)
|
|
patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, device)
|
|
self.assertIsNotNone(patch_code)
|
|
|
|
(test_proc, _, repro_proc), _ = self._run_full_test(
|
|
run_code, "aot", 3, patch_code
|
|
)
|
|
|
|
self.assertNotIn("CompilerError", test_proc.stderr.decode("utf-8"))
|
|
|
|
self.assertEqual(test_proc.returncode, repro_proc.returncode)
|
|
self.assertNotEqual(test_proc.returncode, 0)
|
|
|
|
def test_after_aot_cpu_runtime_error(self):
|
|
self._test_after_aot_runtime_error("cpu", CPP_RUNTIME_ERROR)
|
|
|
|
# NOTE: there is currently not an easy way to cause a triton runtime error.
|
|
@unittest.skip
|
|
@requires_cuda()
|
|
def test_after_aot_cuda_runtime_error(self):
|
|
self._test_after_aot_runtime_error("cuda", TRITON_RUNTIME_ERROR)
|
|
|
|
# Ensure that inductor codegen patches pass when relu is not present.
|
|
def _test_after_aot_backend_passes(self, device, repro_level, backend_code):
|
|
run_code = textwrap.dedent(
|
|
f"""\
|
|
@torch._dynamo.optimize("inductor")
|
|
def inner(x):
|
|
for _ in range(3):
|
|
x = torch.sin(x)
|
|
for _ in range(3):
|
|
x = torch.cos(x)
|
|
return x
|
|
|
|
inner(torch.randn(20, 20).to("{device}"))
|
|
"""
|
|
)
|
|
patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, device)
|
|
self.assertIsNotNone(patch_code)
|
|
|
|
test_code = self._gen_test_code(run_code, "aot", repro_level, patch_code)
|
|
proc, repro_dir = self._run_test_code(test_code)
|
|
self.assertEqual(proc.returncode, 0)
|
|
self.assertIsNone(repro_dir)
|
|
|
|
def test_after_aot_cpu_compile_backend_passes(self):
|
|
self._test_after_aot_backend_passes("cpu", 2, CPP_COMPILE_ERROR)
|
|
|
|
def test_after_aot_cpu_runtime_backend_passes(self):
|
|
self._test_after_aot_backend_passes("cpu", 2, CPP_RUNTIME_ERROR)
|
|
|
|
def test_after_aot_cpu_accuracy_backend_passes(self):
|
|
self._test_after_aot_backend_passes("cpu", 4, CPP_ACCURACY_ERROR)
|
|
|
|
@requires_cuda()
|
|
def test_after_aot_cuda_compile_backend_passes(self):
|
|
self._test_after_aot_backend_passes("cuda", 2, TRITON_COMPILE_ERROR)
|
|
|
|
# NOTE: there is currently not an easy way to cause a triton runtime error.
|
|
@unittest.skip
|
|
@requires_cuda()
|
|
def test_after_aot_cuda_runtime_backend_passes(self):
|
|
self._test_after_aot_backend_passes("cuda", 2, TRITON_RUNTIME_ERROR)
|
|
|
|
@requires_cuda()
|
|
def test_after_aot_cuda_accuracy_backend_passes(self):
|
|
self._test_after_aot_backend_passes("cuda", 4, TRITON_ACCURACY_ERROR)
|
|
|
|
# Test that inductor config can be saved and restored, especially class
|
|
# variables.
|
|
def test_inductor_config_serialization(self):
|
|
run_code = textwrap.dedent(
|
|
"""\
|
|
import torch._inductor.config
|
|
torch._inductor.config.cpp.threads = 5
|
|
data = torch._inductor.config.save_config()
|
|
torch._inductor.config.cpp.threads = 10
|
|
torch._inductor.config.load_config(data)
|
|
assert torch._inductor.config.cpp.threads == 5
|
|
"""
|
|
)
|
|
proc, _ = self._run_test_code(run_code)
|
|
self.assertEqual(proc.returncode, 0)
|
|
|
|
# Test that launched minifier processes have the same config as
|
|
# the original process.
|
|
def _test_after_aot_with_modified_config(self, backend_code, repro_level):
|
|
lines = backend_code.split("\n")
|
|
lines.insert(1, " assert torch._inductor.config.cpp.threads == 10")
|
|
backend_code = "\n".join(lines)
|
|
run_code = textwrap.dedent(
|
|
"""\
|
|
torch._inductor.config.cpp.threads = 10
|
|
@torch._dynamo.optimize("inductor")
|
|
def inner(x):
|
|
for _ in range(3):
|
|
x = torch.sin(x)
|
|
x = torch.relu(x)
|
|
for _ in range(3):
|
|
x = torch.cos(x)
|
|
return x
|
|
|
|
inner(torch.randn(20, 20).to("cpu"))
|
|
"""
|
|
)
|
|
patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, "cpu")
|
|
self.assertIsNotNone(patch_code)
|
|
(test_proc, _, repro_proc), _ = self._run_full_test(
|
|
run_code, "aot", repro_level, patch_code
|
|
)
|
|
return (test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8"))
|
|
|
|
def test_after_aot_with_modified_config_compile_error(self):
|
|
tb1, tb2 = self._test_after_aot_with_modified_config(CPP_COMPILE_ERROR, 2)
|
|
self.assertIn("CppCompileError", tb1)
|
|
self.assertIn("CppCompileError", tb2)
|
|
|
|
def test_after_aot_with_modified_config_accuracy_error(self):
|
|
tb1, tb2 = self._test_after_aot_with_modified_config(CPP_ACCURACY_ERROR, 4)
|
|
self.assertIn("AccuracyError", tb1)
|
|
self.assertIn("AccuracyError", tb2)
|
|
|
|
# Test that default torch.compile can be minified.
|
|
def _test_torch_compile(self, repro_after, repro_level, backend_code):
|
|
run_code = textwrap.dedent(
|
|
"""\
|
|
def inner(x):
|
|
for _ in range(3):
|
|
x = torch.sin(x)
|
|
x = torch.relu(x)
|
|
for _ in range(3):
|
|
x = torch.cos(x)
|
|
return x
|
|
|
|
inner_opt = torch.compile(inner)
|
|
|
|
inner_opt(torch.randn(20, 20))
|
|
"""
|
|
)
|
|
|
|
patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, "cpu")
|
|
self.assertIsNotNone(patch_code)
|
|
|
|
(test_proc, _, repro_proc), _ = self._run_full_test(
|
|
run_code, repro_after, repro_level, patch_code
|
|
)
|
|
return (
|
|
(test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")),
|
|
(test_proc.returncode, repro_proc.returncode),
|
|
)
|
|
|
|
def test_torch_compile_after_dynamo_compile_error(self):
|
|
(tb1, tb2), _ = self._test_torch_compile("dynamo", 2, CPP_COMPILE_ERROR)
|
|
self.assertIn("CppCompileError", tb1)
|
|
self.assertIn("CppCompileError", tb2)
|
|
|
|
def test_torch_compile_after_dynamo_accuracy_error(self):
|
|
(tb1, tb2), _ = self._test_torch_compile("dynamo", 4, CPP_ACCURACY_ERROR)
|
|
self.assertIn("AccuracyError", tb1)
|
|
self.assertIn("AccuracyError", tb2)
|
|
|
|
def test_torch_compile_after_aot_compile_error(self):
|
|
(tb1, tb2), _ = self._test_torch_compile("aot", 2, CPP_COMPILE_ERROR)
|
|
self.assertIn("CppCompileError", tb1)
|
|
self.assertIn("CppCompileError", tb2)
|
|
|
|
def test_torch_compile_after_aot_accuracy_error(self):
|
|
(tb1, tb2), _ = self._test_torch_compile("aot", 4, CPP_ACCURACY_ERROR)
|
|
self.assertIn("AccuracyError", tb1)
|
|
self.assertIn("AccuracyError", tb2)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._dynamo.test_case import run_tests
|
|
|
|
# skip CI tests on mac since CPU inductor does not seem to work due to C++ compile errors
|
|
if not IS_MACOS:
|
|
run_tests()
|