Refactor minifier tests to be more compact (#100471)

Mostly burning more assumptions into the common test helpers, based on the commonality across the tests, so that writing new tests takes less code.
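
For example, after this change an inductor minifier test reduces to the program under test plus the error string it should produce; the repro_level selection, config patching, and stderr assertions all live in the shared driver. This test appears verbatim in the diff below:

    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "compile_error")
    def test_after_aot_cpu_compile_error(self):
        self._test_after_aot("cpu", "CppCompileError")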

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100471
Approved by: https://github.com/voznesenskym
Authored by Edward Z. Yang on 2023-05-02 11:43:43 -07:00; committed by PyTorch MergeBot.
parent 409fc7a4c7
commit 2089a9bd48
4 changed files with 188 additions and 212 deletions

test/dynamo/test_minifier.py

@@ -1,7 +1,5 @@
 # Owner(s): ["module: dynamo"]
 import functools
-import re
-import textwrap
 import unittest
 
 import torch._dynamo
@@ -14,64 +12,56 @@ requires_cuda = functools.partial(
 class MinifierTests(MinifierTestBase):
     # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA)
-    def _test_after_dynamo(self, device, repro_level, backend, error_name):
-        run_code = textwrap.dedent(
-            f"""\
-            @torch._dynamo.optimize({backend!r})
-            def inner(x):
-                for _ in range(10):
-                    x = torch.sin(x)
-                x = torch.relu(x)
-                for _ in range(10):
-                    x = torch.cos(x)
-                return x
-
-            inner(torch.randn(20, 20).to("{device}"))
-        """
-        )
-        test_proc, _, repro_proc = self._run_full_test_nocode(
-            run_code, "dynamo", repro_level, "", isolate=False
-        )
-        self.assertIn(error_name, test_proc.stderr.decode("utf-8"))
-        self.assertIn(error_name, repro_proc.stderr.decode("utf-8"))
+    def _test_after_dynamo(self, device, backend, expected_error):
+        run_code = f"""\
+@torch._dynamo.optimize({backend!r})
+def inner(x):
+    for _ in range(10):
+        x = torch.sin(x)
+    x = torch.relu(x)
+    for _ in range(10):
+        x = torch.cos(x)
+    return x
+
+inner(torch.randn(20, 20).to("{device}"))
+"""
+        self._run_full_test(run_code, "dynamo", expected_error, isolate=False)
 
     def test_after_dynamo_cpu_compile_error(self):
         self._test_after_dynamo(
-            "cpu", 2, "relu_compile_error_TESTING_ONLY", "ReluCompileError"
+            "cpu", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
         )
 
     def test_after_dynamo_cpu_runtime_error(self):
         self._test_after_dynamo(
-            "cpu", 2, "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
+            "cpu", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
         )
 
     def test_after_dynamo_cpu_accuracy_error(self):
         self._test_after_dynamo(
-            "cpu", 4, "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
+            "cpu", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
        )
 
     @requires_cuda()
     def test_after_dynamo_cuda_compile_error(self):
         self._test_after_dynamo(
-            "cuda", 2, "relu_compile_error_TESTING_ONLY", "ReluCompileError"
+            "cuda", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
         )
 
     @requires_cuda()
     def test_after_dynamo_cuda_runtime_error(self):
         self._test_after_dynamo(
-            "cuda", 2, "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
+            "cuda", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
        )
 
     @requires_cuda()
     def test_after_dynamo_cuda_accuracy_error(self):
         self._test_after_dynamo(
-            "cuda", 4, "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
+            "cuda", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
        )
 
     # Ensure that the testing backends pass when relu is not present.
-    def _test_after_dynamo_backend_passes(self, device, repro_level, backend):
+    def _test_after_dynamo_backend_passes(self, device, backend):
         @torch._dynamo.optimize(backend)
         def inner(x):
             for _ in range(10):
@@ -83,149 +73,133 @@ class MinifierTests(MinifierTestBase):
         inner(torch.randn(20, 20).to(device))
 
     def test_after_dynamo_cpu_compile_backend_passes(self):
-        self._test_after_dynamo_backend_passes(
-            "cpu", 2, "relu_compile_error_TESTING_ONLY"
-        )
+        self._test_after_dynamo_backend_passes("cpu", "relu_compile_error_TESTING_ONLY")
 
     def test_after_dynamo_cpu_runtime_backend_passes(self):
-        self._test_after_dynamo_backend_passes(
-            "cpu", 2, "relu_runtime_error_TESTING_ONLY"
-        )
+        self._test_after_dynamo_backend_passes("cpu", "relu_runtime_error_TESTING_ONLY")
 
     def test_after_dynamo_cpu_accuracy_backend_passes(self):
         self._test_after_dynamo_backend_passes(
-            "cpu", 4, "relu_accuracy_error_TESTING_ONLY"
+            "cpu", "relu_accuracy_error_TESTING_ONLY"
         )
 
     @requires_cuda()
     def test_after_dynamo_cuda_compile_backend_passes(self):
         self._test_after_dynamo_backend_passes(
-            "cuda", 2, "relu_compile_error_TESTING_ONLY"
+            "cuda", "relu_compile_error_TESTING_ONLY"
         )
 
     @requires_cuda()
     def test_after_dynamo_cuda_runtime_backend_passes(self):
         self._test_after_dynamo_backend_passes(
-            "cuda", 2, "relu_runtime_error_TESTING_ONLY"
+            "cuda", "relu_runtime_error_TESTING_ONLY"
         )
 
     @requires_cuda()
     def test_after_dynamo_cuda_accuracy_backend_passes(self):
         self._test_after_dynamo_backend_passes(
-            "cuda", 4, "relu_accuracy_error_TESTING_ONLY"
+            "cuda", "relu_accuracy_error_TESTING_ONLY"
         )
 
     # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd
     @requires_cuda()
     def test_cpu_cuda_module_after_dynamo(self):
         backend_name = "relu_compile_error_TESTING_ONLY"
-        run_code = textwrap.dedent(
-            f"""\
-            class CpuCudaModule(torch.nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.m_x = torch.nn.Linear(20, 20).cuda()
-                    self.m_y = torch.nn.Linear(20, 20)
-                    self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda())
-                    self.p_y = torch.nn.Parameter(torch.randn(20, 20))
-                    self.register_buffer("b_x", torch.ones(20, 20).cuda())
-                    self.register_buffer("b_y", torch.ones(20, 20))
-
-                def forward(self, x, y):
-                    return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y
-
-            mod = CpuCudaModule()
-
-            @torch._dynamo.optimize({backend_name!r})
-            def inner(x1, y1):
-                x2 = torch.randn(20, 20).cuda()
-                y2 = torch.randn(20, 20)
-                x3, y3 = mod(x1 + x2, y1 + y2)
-                return torch.relu(x3.cpu() + y3)
-
-            inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
-        """
-        )
+        run_code = f"""\
+class CpuCudaModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.m_x = torch.nn.Linear(20, 20).cuda()
+        self.m_y = torch.nn.Linear(20, 20)
+        self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda())
+        self.p_y = torch.nn.Parameter(torch.randn(20, 20))
+        self.register_buffer("b_x", torch.ones(20, 20).cuda())
+        self.register_buffer("b_y", torch.ones(20, 20))
+
+    def forward(self, x, y):
+        return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y
+
+mod = CpuCudaModule()
+
+@torch._dynamo.optimize({backend_name!r})
+def inner(x1, y1):
+    x2 = torch.randn(20, 20).cuda()
+    y2 = torch.randn(20, 20)
+    x3, y3 = mod(x1 + x2, y1 + y2)
+    return torch.relu(x3.cpu() + y3)
+
+inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
+"""
 
-        (test_proc, _, repro_proc), (launch_code, _) = self._run_full_test(
-            run_code, "dynamo", 2, "", isolate=False
-        )
-
-        tb1 = test_proc.stderr.decode("utf-8")
-        tb2 = repro_proc.stderr.decode("utf-8")
-
-        # Check if generated minifier code covers all cpu/cuda cases
-        self.assertIsNotNone(re.search(r"args.*cuda", launch_code))
-        self.assertIsNotNone(re.search(r"args.*cpu", launch_code))
-        # search for Linear(...).cuda()
-        self.assertIsNotNone(re.search(r"Linear.*cuda", launch_code))
-        # search for Linear(...)
-        self.assertIsNotNone(
-            re.search(r"Linear(?!.*cuda.*$)", launch_code, re.MULTILINE)
-        )
-        self.assertIsNotNone(re.search(r"register_buffer.*cuda", launch_code))
-        self.assertIsNotNone(
-            re.search(r"register_buffer(?!.*cuda.*$)", launch_code, re.MULTILINE)
-        )
-        self.assertIsNotNone(re.search(r"Parameter.*cuda", launch_code))
-        self.assertIsNotNone(
-            re.search(r"Parameter(?!.*cuda.*$)", launch_code, re.MULTILINE)
-        )
-        # search for
-        # <name> = torch.randn(...)
-        # ... = <name>.cuda()
-        self.assertIsNotNone(
-            re.search(r"(\w+) = torch.randn.*\1\.cuda", launch_code, re.DOTALL)
-        )
-        # search for
-        # <name> = torch.randn(...)
-        # no followup call to <name>.cuda()
-        self.assertIsNotNone(
-            re.search(
-                r"(\w+) = torch.randn(?!.*\1\.cuda\(\).*$)", launch_code, re.DOTALL
-            )
-        )
-        self.assertIn(backend_name, tb1)
-        self.assertIn(backend_name, tb2)
+        res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)
+
+        self.assertExpectedInline(
+            res.minifier_module(),
+            """\
+class Repro(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
+        self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
+        self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).cuda())
+        self.register_buffer('G__mod___b_y', torch.randn([20, 20], dtype=torch.float32))
+        self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32)).cuda()
+        self.G__mod___p_y = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32))
+
+    def forward(self, L_x1_ : torch.Tensor, L_y1_ : torch.Tensor):
+        l_x1_ = L_x1_
+        l_y1_ = L_y1_
+        randn = torch.randn(20, 20)
+        cuda = randn.cuda(); randn = None
+        randn_1 = torch.randn(20, 20)
+        add = l_x1_ + cuda; l_x1_ = cuda = None
+        add_1 = l_y1_ + randn_1; l_y1_ = randn_1 = None
+        g__mod___m_x = self.G__mod___m_x(add); add = None
+        g__mod___p_x = self.G__mod___p_x
+        add_2 = g__mod___m_x + g__mod___p_x; g__mod___m_x = g__mod___p_x = None
+        g__mod___b_x = self.G__mod___b_x
+        add_3 = add_2 + g__mod___b_x; add_2 = g__mod___b_x = None
+        g__mod___m_y = self.G__mod___m_y(add_1); add_1 = None
+        g__mod___p_y = self.G__mod___p_y
+        add_4 = g__mod___m_y + g__mod___p_y; g__mod___m_y = g__mod___p_y = None
+        g__mod___b_y = self.G__mod___b_y
+        add_5 = add_4 + g__mod___b_y; add_4 = g__mod___b_y = None
+        cpu = add_3.cpu(); add_3 = None
+        add_6 = cpu + add_5; cpu = add_5 = None
+        relu = torch.relu(add_6); add_6 = None
+        return (relu,)""",
+        )
 
     # Test if we can actually get a minified graph
     def test_if_graph_minified(self):
         backend_name = "relu_compile_error_TESTING_ONLY"
-        run_code = textwrap.dedent(
-            f"""\
-            @torch._dynamo.optimize({backend_name!r})
-            def inner(x):
-                for _ in range(20):
-                    x = torch.sin(x)
-                x = torch.relu(x)
-                for _ in range(20):
-                    x = torch.cos(x)
-                return x
-
-            inner(torch.randn(20, 20))
-        """
-        )
+        run_code = f"""\
+@torch._dynamo.optimize({backend_name!r})
+def inner(x):
+    for _ in range(20):
+        x = torch.sin(x)
+    x = torch.relu(x)
+    for _ in range(20):
+        x = torch.cos(x)
+    return x
+
+inner(torch.randn(20, 20))
+"""
 
-        (test_proc, _, repro_proc), (launch_code, repro_code) = self._run_full_test(
-            run_code, "dynamo", 2, "", isolate=False
-        )
-
-        tb1 = test_proc.stderr.decode("utf-8")
-        tb2 = repro_proc.stderr.decode("utf-8")
-
-        self.assertIn(backend_name, tb1)
-        self.assertIn(backend_name, tb2)
-
-        # compare the length of the forward functions
-        match = re.search(r"def forward.*return", launch_code, re.DOTALL)
-        self.assertIsNotNone(match)
-        self.assertGreater(match.group(0).count("\n"), 40)
-
-        match = re.search(r"def forward.*return", repro_code, re.DOTALL)
-        self.assertIsNotNone(match)
-        self.assertLess(match.group(0).count("\n"), 5)
+        res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)
+
+        self.assertExpectedInline(
+            res.repro_module(),
+            """\
+class Repro(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, cos, sin_19):
+        relu = torch.relu(sin_19); sin_19 = None
+        return (cos,)""",
+        )
 
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
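
A note on the new assertions: assertExpectedInline comes from PyTorch's expecttest harness, so when the minified module legitimately changes, the inline expected strings can be regenerated rather than hand-edited by rerunning the test in accept mode, e.g. (command is illustrative):

    EXPECTTEST_ACCEPT=1 python test/dynamo/test_minifier.py -k test_if_graph_minified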

test/inductor/test_minifier.py

@@ -1,10 +1,10 @@
 # Owner(s): ["module: inductor"]
 import functools
-import textwrap
 import unittest
 
+import torch
 import torch._dynamo
 import torch._inductor.config as inductor_config
 import torch._inductor.utils
 from torch._dynamo.test_minifier_common import MinifierTestBase
 from torch.testing._internal.common_utils import IS_JETSON, IS_MACOS, TEST_WITH_ASAN
@@ -15,54 +15,40 @@ requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cu
 class MinifierTests(MinifierTestBase):
     # Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA)
-    def _test_after_aot(self, device, bug_type, repro_level):
+    def _test_after_aot(self, device, expected_error):
         # NB: The program is intentionally quite simple, just enough to
         # trigger one minification step, no more (dedicated minifier tests
         # should exercise minifier only)
-        run_code = textwrap.dedent(
-            f"""\
-            @torch.compile()
-            def inner(x):
-                x = torch.relu(x)
-                x = torch.cos(x)
-                return x
-
-            inner(torch.randn(20, 20).to("{device}"))
-        """
-        )
-        # These will crash the process and should be tested in
-        # test_minifier_isolate.py
-        assert bug_type != "runtime_error"
-        patch_code = self._gen_codegen_fn_patch_code(device, bug_type)
-        self.assertIsNotNone(patch_code)
-        test_proc, _, repro_proc = self._run_full_test_nocode(
-            run_code, "aot", repro_level, patch_code, isolate=False
-        )
-        return test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")
+        run_code = f"""\
+@torch.compile()
+def inner(x):
+    x = torch.relu(x)
+    x = torch.cos(x)
+    return x
+
+inner(torch.randn(20, 20).to("{device}"))
+"""
+        self._run_full_test(run_code, "aot", expected_error, isolate=False)
 
     @unittest.skipIf(IS_JETSON, "Fails on Jetson")
+    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "compile_error")
     def test_after_aot_cpu_compile_error(self):
-        tb1, tb2 = self._test_after_aot("cpu", "compile_error", 2)
-        self.assertIn("CppCompileError", tb1)
-        self.assertIn("CppCompileError", tb2)
+        self._test_after_aot("cpu", "CppCompileError")
 
     @unittest.skipIf(IS_JETSON, "Fails on Jetson")
+    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
     def test_after_aot_cpu_accuracy_error(self):
-        tb1, tb2 = self._test_after_aot("cpu", "accuracy", 4)
-        self.assertIn("AccuracyError", tb1)
-        self.assertIn("AccuracyError", tb2)
+        self._test_after_aot("cpu", "AccuracyError")
 
     @requires_cuda()
+    @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "compile_error")
     def test_after_aot_cuda_compile_error(self):
-        tb1, tb2 = self._test_after_aot("cuda", "compile_error", 2)
-        self.assertIn("SyntaxError", tb1)
-        self.assertIn("SyntaxError", tb2)
+        self._test_after_aot("cuda", "SyntaxError")
 
     @requires_cuda()
+    @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy")
     def test_after_aot_cuda_accuracy_error(self):
-        tb1, tb2 = self._test_after_aot("cuda", "accuracy", 4)
-        self.assertIn("AccuracyError", tb1)
-        self.assertIn("AccuracyError", tb2)
+        self._test_after_aot("cuda", "AccuracyError")
 
 
 if __name__ == "__main__":
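
The @inductor_config.patch decorators can replace the old patch_code plumbing because the generated test script now serializes the parent process's config via codegen_config() (see the _gen_test_code change in torch/_dynamo/test_minifier_common.py below), so a config value patched in the test process is replayed inside the generated script. A minimal sketch of that mechanism, assuming only that codegen_config() emits executable Python for non-default values:

    import torch._inductor.config as inductor_config

    # Inside the patch, the injected bug flag is non-default, so it shows up
    # in the generated config code and hence in any script that embeds it.
    with inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "compile_error"):
        print(inductor_config.codegen_config())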

test/inductor/test_minifier_isolate.py

@@ -1,10 +1,10 @@
 # Owner(s): ["module: inductor"]
 import functools
-import textwrap
 import unittest
 
+import torch
 import torch._dynamo
 import torch._inductor.config as inductor_config
 import torch._inductor.utils
 from torch._dynamo.test_minifier_common import MinifierTestBase
 from torch.testing._internal.common_utils import IS_JETSON, IS_MACOS, TEST_WITH_ASAN
@@ -16,38 +16,28 @@ requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cu
 # These minifier tests are slow, because they must be run in separate
 # subprocesses
 class MinifierIsolateTests(MinifierTestBase):
-    def _test_after_aot_runtime_error(self, device, bug_type):
-        run_code = textwrap.dedent(
-            f"""\
-            @torch.compile()
-            def inner(x):
-                x = torch.relu(x)
-                x = torch.cos(x)
-                return x
-
-            inner(torch.randn(20, 20).to("{device}"))
-        """
-        )
-        patch_code = self._gen_codegen_fn_patch_code(device, bug_type)
-        self.assertIsNotNone(patch_code)
-
+    def _test_after_aot_runtime_error(self, device, expected_error):
+        run_code = f"""\
+@torch.compile()
+def inner(x):
+    x = torch.relu(x)
+    x = torch.cos(x)
+    return x
+
+inner(torch.randn(20, 20).to("{device}"))
+"""
         # These must isolate because they crash the process
-        test_proc, _, repro_proc = self._run_full_test_nocode(
-            run_code, "aot", 3, patch_code, isolate=True
-        )
-
-        self.assertNotIn("CompilerError", test_proc.stderr.decode("utf-8"))
-
-        self.assertEqual(test_proc.returncode, repro_proc.returncode)
-        self.assertNotEqual(test_proc.returncode, 0)
+        self._run_full_test(run_code, "aot", expected_error, isolate=True)
 
     @unittest.skipIf(IS_JETSON, "Fails on Jetson")
+    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "runtime_error")
     def test_after_aot_cpu_runtime_error(self):
-        self._test_after_aot_runtime_error("cpu", "runtime_error")
+        self._test_after_aot_runtime_error("cpu", "")
 
     @requires_cuda()
+    @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "runtime_error")
     def test_after_aot_cuda_runtime_error(self):
-        self._test_after_aot_runtime_error("cuda", "runtime_error")
+        self._test_after_aot_runtime_error("cuda", "device-side assert")
 
 
 if __name__ == "__main__":
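
Note that the CPU case passes an empty expected_error, presumably because the injected C++ runtime error produces no stable message: assertIn("", stderr) is vacuously true, so that test leans entirely on the nonzero-returncode check that _run_full_test performs on the repro process.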

torch/_dynamo/test_minifier_common.py

@@ -1,9 +1,11 @@
+import dataclasses
 import io
 import logging
 import os
+import re
 import shutil
 import subprocess
 import sys
 import tempfile
 import traceback
 from unittest.mock import patch
@@ -13,6 +15,26 @@ import torch._dynamo
 import torch._dynamo.test_case
 
 
+@dataclasses.dataclass
+class MinifierTestResult:
+    minifier_code: str
+    repro_code: str
+
+    def _get_module(self, t):
+        r = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t).group(
+            0
+        )
+        r = re.sub(r"\s+$", "\n", r, flags=re.MULTILINE)
+        r = re.sub(r"\n{3,}", "\n\n", r)
+        return r.strip()
+
+    def minifier_module(self):
+        return self._get_module(self.minifier_code)
+
+    def repro_module(self):
+        return self._get_module(self.repro_code)
+
+
 class MinifierTestBase(torch._dynamo.test_case.TestCase):
     DEBUG_DIR = tempfile.mkdtemp()
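
To make the new helper concrete: _get_module extracts the printed Repro class from a generated script and normalizes whitespace so it can be compared with assertExpectedInline. A self-contained illustration using the same regexes on a hypothetical input:

    import re

    t = 'x = 1\nclass Repro(torch.nn.Module):\n    def forward(self, x):\n        return x   \n\n\n\nmod = Repro()\n'
    # grab "class Repro(...):" plus every following indented or blank line
    r = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t).group(0)
    r = re.sub(r"\s+$", "\n", r, flags=re.MULTILINE)  # strip trailing whitespace per line
    r = re.sub(r"\n{3,}", "\n\n", r)  # collapse runs of blank lines inside the class body
    print(r.strip())
    # class Repro(torch.nn.Module):
    #     def forward(self, x):
    #         return x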
@@ -95,6 +117,9 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
             log.removeHandler(log_handler)
             if cwd is not None:
                 os.chdir(prev_cwd)
+            # Make sure we don't leave buggy compiled frames lying
+            # around
+            torch._dynamo.reset()
         finally:
             object.__setattr__(torch._dynamo.config, "_config", dynamo_config)
             object.__setattr__(torch._inductor.config, "_config", inductor_config)
@@ -162,11 +187,12 @@ torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"
     # `run_code` is the code to run for the test case.
-    # `patch_code` is the code to be patched in every generated file; usually
-    # just use this to turn on bugs via the config
-    def _gen_test_code(self, run_code, repro_after, repro_level, patch_code):
+    def _gen_test_code(self, run_code, repro_after, repro_level):
         return f"""\
 import torch
 import torch._dynamo
-{patch_code}
+{torch._dynamo.config.codegen_config()}
+{torch._inductor.config.codegen_config()}
 torch._dynamo.config.repro_after = "{repro_after}"
 torch._dynamo.config.repro_level = {repro_level}
 torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"
@@ -175,31 +201,31 @@ torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"
 
     # Runs a full minifier test.
     # Minifier tests generally consist of 3 stages:
-    # 1. Run the problematic code (in a separate process since it could segfault)
+    # 1. Run the problematic code
     # 2. Run the generated minifier launcher script
     # 3. Run the generated repro script
     #
     # If possible, you should run the test with isolate=False; use
     # isolate=True only if the bug you're testing would otherwise
     # crash the process
-    def _run_full_test(
-        self, run_code, repro_after, repro_level, patch_code, *, isolate
-    ):
-        test_code = self._gen_test_code(run_code, repro_after, repro_level, patch_code)
+    def _run_full_test(self, run_code, repro_after, expected_error, *, isolate):
+        if isolate:
+            repro_level = 3
+        else:
+            repro_level = 4 if expected_error == "AccuracyError" else 2
+        test_code = self._gen_test_code(run_code, repro_after, repro_level)
         print("running test", file=sys.stderr)
         test_proc, repro_dir = self._run_test_code(test_code, isolate=isolate)
+        # NB: Intentionally do not test return code; we only care about
+        # actually generating the repro, we don't have to crash
+        self.assertIn(expected_error, test_proc.stderr.decode("utf-8"))
         self.assertIsNotNone(repro_dir)
-        print("running minifier")
-        launch_proc, launch_code = self._run_minifier_launcher(
+        print("running minifier", file=sys.stderr)
+        minifier_proc, minifier_code = self._run_minifier_launcher(
             repro_dir, isolate=isolate
         )
-        print("running repro")
+        print("running repro", file=sys.stderr)
         repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate)
-        return (test_proc, launch_proc, repro_proc), (launch_code, repro_code)
-
-    def _run_full_test_nocode(
-        self, run_code, repro_after, repro_level, patch_code, *, isolate
-    ):
-        tbs, _ = self._run_full_test(
-            run_code, repro_after, repro_level, patch_code, isolate=isolate
-        )
-        return tbs
+        self.assertIn(expected_error, repro_proc.stderr.decode("utf-8"))
+        self.assertNotEqual(repro_proc.returncode, 0)
+        return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code)
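
The repro_level values that _run_full_test now picks correspond to torch._dynamo.config's documented levels (2: dump a minifier_launcher.py when the compiled run fails; 3: always dump one, which is what a crashing, isolated process needs; 4: dump one when an accuracy check fails). Restated as a standalone sketch of the burned-in selection:

    # Illustrative restatement of the logic inside _run_full_test above.
    def choose_repro_level(expected_error: str, isolate: bool) -> int:
        if isolate:
            return 3  # always dump: the process may die before error handling runs
        if expected_error == "AccuracyError":
            return 4  # dump on accuracy failure
        return 2  # dump on compile/runtime failure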