Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Refactor minifier tests to be more compact (#100471)
Mostly burning in more assumptions based on the commonality across the tests, so writing new tests takes less code.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100471
Approved by: https://github.com/voznesenskym
Committed by: PyTorch MergeBot
Parent: 409fc7a4c7
Commit: 2089a9bd48
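The shape of the consolidation is easiest to see in the new _run_full_test helper in torch/_dynamo/test_minifier_common.py (the last file in the diff below): tests now state only the error string they expect, and the harness derives the repro level itself. The following is a minimal sketch of that derivation; the standalone function name _pick_repro_level is a hypothetical stand-in for logic that actually lives inline in _run_full_test.

# Sketch only: paraphrases the repro_level selection added to _run_full_test.
# "_pick_repro_level" is a made-up name for illustration; the real code is inline.
def _pick_repro_level(expected_error: str, isolate: bool) -> int:
    if isolate:
        # Isolated tests exercise process-crashing bugs, which use repro level 3.
        return 3
    # Accuracy failures minify at level 4; compile/runtime failures at level 2.
    return 4 if expected_error == "AccuracyError" else 2

assert _pick_repro_level("AccuracyError", isolate=False) == 4
assert _pick_repro_level("ReluCompileError", isolate=False) == 2
assert _pick_repro_level("ReluRuntimeError", isolate=True) == 3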
test/dynamo/test_minifier.py

@@ -1,7 +1,5 @@
# Owner(s): ["module: dynamo"]
import functools
import re
import textwrap
import unittest

import torch._dynamo

@@ -14,64 +12,56 @@ requires_cuda = functools.partial(

class MinifierTests(MinifierTestBase):
# Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA)
def _test_after_dynamo(self, device, repro_level, backend, error_name):
run_code = textwrap.dedent(
f"""\
@torch._dynamo.optimize({backend!r})
def inner(x):
for _ in range(10):
x = torch.sin(x)
x = torch.relu(x)
for _ in range(10):
x = torch.cos(x)
return x
def _test_after_dynamo(self, device, backend, expected_error):
run_code = f"""\
@torch._dynamo.optimize({backend!r})
def inner(x):
for _ in range(10):
x = torch.sin(x)
x = torch.relu(x)
for _ in range(10):
x = torch.cos(x)
return x

inner(torch.randn(20, 20).to("{device}"))
"""
)

test_proc, _, repro_proc = self._run_full_test_nocode(
run_code, "dynamo", repro_level, "", isolate=False
)

self.assertIn(error_name, test_proc.stderr.decode("utf-8"))
self.assertIn(error_name, repro_proc.stderr.decode("utf-8"))
inner(torch.randn(20, 20).to("{device}"))
"""
self._run_full_test(run_code, "dynamo", expected_error, isolate=False)

def test_after_dynamo_cpu_compile_error(self):
self._test_after_dynamo(
"cpu", 2, "relu_compile_error_TESTING_ONLY", "ReluCompileError"
"cpu", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
)

def test_after_dynamo_cpu_runtime_error(self):
self._test_after_dynamo(
"cpu", 2, "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
"cpu", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
)

def test_after_dynamo_cpu_accuracy_error(self):
self._test_after_dynamo(
"cpu", 4, "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
"cpu", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
)

@requires_cuda()
def test_after_dynamo_cuda_compile_error(self):
self._test_after_dynamo(
"cuda", 2, "relu_compile_error_TESTING_ONLY", "ReluCompileError"
"cuda", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
)

@requires_cuda()
def test_after_dynamo_cuda_runtime_error(self):
self._test_after_dynamo(
"cuda", 2, "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
"cuda", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
)

@requires_cuda()
def test_after_dynamo_cuda_accuracy_error(self):
self._test_after_dynamo(
"cuda", 4, "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
"cuda", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
)

# Ensure that the testing backends pass when relu is not present.
def _test_after_dynamo_backend_passes(self, device, repro_level, backend):
def _test_after_dynamo_backend_passes(self, device, backend):
@torch._dynamo.optimize(backend)
def inner(x):
for _ in range(10):

@@ -83,149 +73,133 @@ class MinifierTests(MinifierTestBase):
inner(torch.randn(20, 20).to(device))

def test_after_dynamo_cpu_compile_backend_passes(self):
self._test_after_dynamo_backend_passes(
"cpu", 2, "relu_compile_error_TESTING_ONLY"
)
self._test_after_dynamo_backend_passes("cpu", "relu_compile_error_TESTING_ONLY")

def test_after_dynamo_cpu_runtime_backend_passes(self):
self._test_after_dynamo_backend_passes(
"cpu", 2, "relu_runtime_error_TESTING_ONLY"
)
self._test_after_dynamo_backend_passes("cpu", "relu_runtime_error_TESTING_ONLY")

def test_after_dynamo_cpu_accuracy_backend_passes(self):
self._test_after_dynamo_backend_passes(
"cpu", 4, "relu_accuracy_error_TESTING_ONLY"
"cpu", "relu_accuracy_error_TESTING_ONLY"
)

@requires_cuda()
def test_after_dynamo_cuda_compile_backend_passes(self):
self._test_after_dynamo_backend_passes(
"cuda", 2, "relu_compile_error_TESTING_ONLY"
"cuda", "relu_compile_error_TESTING_ONLY"
)

@requires_cuda()
def test_after_dynamo_cuda_runtime_backend_passes(self):
self._test_after_dynamo_backend_passes(
"cuda", 2, "relu_runtime_error_TESTING_ONLY"
"cuda", "relu_runtime_error_TESTING_ONLY"
)

@requires_cuda()
def test_after_dynamo_cuda_accuracy_backend_passes(self):
self._test_after_dynamo_backend_passes(
"cuda", 4, "relu_accuracy_error_TESTING_ONLY"
"cuda", "relu_accuracy_error_TESTING_ONLY"
)

# Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd
@requires_cuda()
def test_cpu_cuda_module_after_dynamo(self):
backend_name = "relu_compile_error_TESTING_ONLY"
run_code = textwrap.dedent(
f"""\
class CpuCudaModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.m_x = torch.nn.Linear(20, 20).cuda()
self.m_y = torch.nn.Linear(20, 20)
self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda())
self.p_y = torch.nn.Parameter(torch.randn(20, 20))
self.register_buffer("b_x", torch.ones(20, 20).cuda())
self.register_buffer("b_y", torch.ones(20, 20))
run_code = f"""\
class CpuCudaModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.m_x = torch.nn.Linear(20, 20).cuda()
self.m_y = torch.nn.Linear(20, 20)
self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda())
self.p_y = torch.nn.Parameter(torch.randn(20, 20))
self.register_buffer("b_x", torch.ones(20, 20).cuda())
self.register_buffer("b_y", torch.ones(20, 20))

def forward(self, x, y):
return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y
def forward(self, x, y):
return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y

mod = CpuCudaModule()
mod = CpuCudaModule()

@torch._dynamo.optimize({backend_name!r})
def inner(x1, y1):
x2 = torch.randn(20, 20).cuda()
y2 = torch.randn(20, 20)
x3, y3 = mod(x1 + x2, y1 + y2)
return torch.relu(x3.cpu() + y3)
@torch._dynamo.optimize({backend_name!r})
def inner(x1, y1):
x2 = torch.randn(20, 20).cuda()
y2 = torch.randn(20, 20)
x3, y3 = mod(x1 + x2, y1 + y2)
return torch.relu(x3.cpu() + y3)

inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
"""
inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
"""

res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)

self.assertExpectedInline(
res.minifier_module(),
"""\
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).cuda())
self.register_buffer('G__mod___b_y', torch.randn([20, 20], dtype=torch.float32))
self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32)).cuda()
self.G__mod___p_y = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32))

def forward(self, L_x1_ : torch.Tensor, L_y1_ : torch.Tensor):
l_x1_ = L_x1_
l_y1_ = L_y1_
randn = torch.randn(20, 20)
cuda = randn.cuda(); randn = None
randn_1 = torch.randn(20, 20)
add = l_x1_ + cuda; l_x1_ = cuda = None
add_1 = l_y1_ + randn_1; l_y1_ = randn_1 = None
g__mod___m_x = self.G__mod___m_x(add); add = None
g__mod___p_x = self.G__mod___p_x
add_2 = g__mod___m_x + g__mod___p_x; g__mod___m_x = g__mod___p_x = None
g__mod___b_x = self.G__mod___b_x
add_3 = add_2 + g__mod___b_x; add_2 = g__mod___b_x = None
g__mod___m_y = self.G__mod___m_y(add_1); add_1 = None
g__mod___p_y = self.G__mod___p_y
add_4 = g__mod___m_y + g__mod___p_y; g__mod___m_y = g__mod___p_y = None
g__mod___b_y = self.G__mod___b_y
add_5 = add_4 + g__mod___b_y; add_4 = g__mod___b_y = None
cpu = add_3.cpu(); add_3 = None
add_6 = cpu + add_5; cpu = add_5 = None
relu = torch.relu(add_6); add_6 = None
return (relu,)""",
)

(test_proc, _, repro_proc), (launch_code, _) = self._run_full_test(
run_code, "dynamo", 2, "", isolate=False
)

tb1 = test_proc.stderr.decode("utf-8")
tb2 = repro_proc.stderr.decode("utf-8")

# Check if generated minifier code covers all cpu/cuda cases
self.assertIsNotNone(re.search(r"args.*cuda", launch_code))
self.assertIsNotNone(re.search(r"args.*cpu", launch_code))
# search for Linear(...).cuda()
self.assertIsNotNone(re.search(r"Linear.*cuda", launch_code))
# search for Linear(...)
self.assertIsNotNone(
re.search(r"Linear(?!.*cuda.*$)", launch_code, re.MULTILINE)
)
self.assertIsNotNone(re.search(r"register_buffer.*cuda", launch_code))
self.assertIsNotNone(
re.search(r"register_buffer(?!.*cuda.*$)", launch_code, re.MULTILINE)
)
self.assertIsNotNone(re.search(r"Parameter.*cuda", launch_code))
self.assertIsNotNone(
re.search(r"Parameter(?!.*cuda.*$)", launch_code, re.MULTILINE)
)
# search for
# <name> = torch.randn(...)
# ... = <name>.cuda()
self.assertIsNotNone(
re.search(r"(\w+) = torch.randn.*\1\.cuda", launch_code, re.DOTALL)
)
# search for
# <name> = torch.randn(...)
# no followup call to <name>.cuda()
self.assertIsNotNone(
re.search(
r"(\w+) = torch.randn(?!.*\1\.cuda\(\).*$)", launch_code, re.DOTALL
)
)

self.assertIn(backend_name, tb1)
self.assertIn(backend_name, tb2)

# Test if we can actually get a minified graph
def test_if_graph_minified(self):
backend_name = "relu_compile_error_TESTING_ONLY"
run_code = textwrap.dedent(
f"""\
@torch._dynamo.optimize({backend_name!r})
def inner(x):
for _ in range(20):
x = torch.sin(x)
x = torch.relu(x)
for _ in range(20):
x = torch.cos(x)
return x
run_code = f"""\
@torch._dynamo.optimize({backend_name!r})
def inner(x):
for _ in range(20):
x = torch.sin(x)
x = torch.relu(x)
for _ in range(20):
x = torch.cos(x)
return x

inner(torch.randn(20, 20))
"""
inner(torch.randn(20, 20))
"""

res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)

self.assertExpectedInline(
res.repro_module(),
"""\
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, cos, sin_19):
relu = torch.relu(sin_19); sin_19 = None
return (cos,)""",
)

(test_proc, _, repro_proc), (launch_code, repro_code) = self._run_full_test(
run_code, "dynamo", 2, "", isolate=False
)

tb1 = test_proc.stderr.decode("utf-8")
tb2 = repro_proc.stderr.decode("utf-8")

self.assertIn(backend_name, tb1)
self.assertIn(backend_name, tb2)

# compare the length of the forward functions
match = re.search(r"def forward.*return", launch_code, re.DOTALL)
self.assertIsNotNone(match)
self.assertGreater(match.group(0).count("\n"), 40)

match = re.search(r"def forward.*return", repro_code, re.DOTALL)
self.assertIsNotNone(match)
self.assertLess(match.group(0).count("\n"), 5)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
test/inductor/test_minifier.py

@@ -1,10 +1,10 @@
# Owner(s): ["module: inductor"]
import functools
import textwrap
import unittest

import torch
import torch._dynamo
import torch._inductor.config as inductor_config
import torch._inductor.utils
from torch._dynamo.test_minifier_common import MinifierTestBase
from torch.testing._internal.common_utils import IS_JETSON, IS_MACOS, TEST_WITH_ASAN

@@ -15,54 +15,40 @@ requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cu

class MinifierTests(MinifierTestBase):
# Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA)
def _test_after_aot(self, device, bug_type, repro_level):
def _test_after_aot(self, device, expected_error):
# NB: The program is intentionally quite simple, just enough to
# trigger one minification step, no more (dedicated minifier tests
# should exercise minifier only)
run_code = textwrap.dedent(
f"""\
@torch.compile()
def inner(x):
x = torch.relu(x)
x = torch.cos(x)
return x
run_code = f"""\
@torch.compile()
def inner(x):
x = torch.relu(x)
x = torch.cos(x)
return x

inner(torch.randn(20, 20).to("{device}"))
"""
)
# These will crash the process and should be tested in
# test_minifier_isolate.py
assert bug_type != "runtime_error"
patch_code = self._gen_codegen_fn_patch_code(device, bug_type)
self.assertIsNotNone(patch_code)
test_proc, _, repro_proc = self._run_full_test_nocode(
run_code, "aot", repro_level, patch_code, isolate=False
)
return test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")
inner(torch.randn(20, 20).to("{device}"))
"""
self._run_full_test(run_code, "aot", expected_error, isolate=False)

@unittest.skipIf(IS_JETSON, "Fails on Jetson")
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "compile_error")
def test_after_aot_cpu_compile_error(self):
tb1, tb2 = self._test_after_aot("cpu", "compile_error", 2)
self.assertIn("CppCompileError", tb1)
self.assertIn("CppCompileError", tb2)
self._test_after_aot("cpu", "CppCompileError")

@unittest.skipIf(IS_JETSON, "Fails on Jetson")
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
def test_after_aot_cpu_accuracy_error(self):
tb1, tb2 = self._test_after_aot("cpu", "accuracy", 4)
self.assertIn("AccuracyError", tb1)
self.assertIn("AccuracyError", tb2)
self._test_after_aot("cpu", "AccuracyError")

@requires_cuda()
@inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "compile_error")
def test_after_aot_cuda_compile_error(self):
tb1, tb2 = self._test_after_aot("cuda", "compile_error", 2)
self.assertIn("SyntaxError", tb1)
self.assertIn("SyntaxError", tb2)
self._test_after_aot("cuda", "SyntaxError")

@requires_cuda()
@inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy")
def test_after_aot_cuda_accuracy_error(self):
tb1, tb2 = self._test_after_aot("cuda", "accuracy", 4)
self.assertIn("AccuracyError", tb1)
self.assertIn("AccuracyError", tb2)
self._test_after_aot("cuda", "AccuracyError")


if __name__ == "__main__":
test/inductor/test_minifier_isolate.py

@@ -1,10 +1,10 @@
# Owner(s): ["module: inductor"]
import functools
import textwrap
import unittest

import torch
import torch._dynamo
import torch._inductor.config as inductor_config
import torch._inductor.utils
from torch._dynamo.test_minifier_common import MinifierTestBase
from torch.testing._internal.common_utils import IS_JETSON, IS_MACOS, TEST_WITH_ASAN

@@ -16,38 +16,28 @@ requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cu
# These minifier tests are slow, because they must be run in separate
# subprocesses
class MinifierIsolateTests(MinifierTestBase):
def _test_after_aot_runtime_error(self, device, bug_type):
run_code = textwrap.dedent(
f"""\
@torch.compile()
def inner(x):
x = torch.relu(x)
x = torch.cos(x)
return x

inner(torch.randn(20, 20).to("{device}"))
"""
)
patch_code = self._gen_codegen_fn_patch_code(device, bug_type)
self.assertIsNotNone(patch_code)
def _test_after_aot_runtime_error(self, device, expected_error):
run_code = f"""\
@torch.compile()
def inner(x):
x = torch.relu(x)
x = torch.cos(x)
return x

inner(torch.randn(20, 20).to("{device}"))
"""
# These must isolate because they crash the process
test_proc, _, repro_proc = self._run_full_test_nocode(
run_code, "aot", 3, patch_code, isolate=True
)

self.assertNotIn("CompilerError", test_proc.stderr.decode("utf-8"))

self.assertEqual(test_proc.returncode, repro_proc.returncode)
self.assertNotEqual(test_proc.returncode, 0)
self._run_full_test(run_code, "aot", expected_error, isolate=True)

@unittest.skipIf(IS_JETSON, "Fails on Jetson")
@inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "runtime_error")
def test_after_aot_cpu_runtime_error(self):
self._test_after_aot_runtime_error("cpu", "runtime_error")
self._test_after_aot_runtime_error("cpu", "")

@requires_cuda()
@inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "runtime_error")
def test_after_aot_cuda_runtime_error(self):
self._test_after_aot_runtime_error("cuda", "runtime_error")
self._test_after_aot_runtime_error("cuda", "device-side assert")


if __name__ == "__main__":
torch/_dynamo/test_minifier_common.py

@@ -1,9 +1,11 @@
import dataclasses
import io
import logging
import os
import re
import shutil
import subprocess
import sys
import tempfile
import traceback
from unittest.mock import patch

@@ -13,6 +15,26 @@ import torch._dynamo
import torch._dynamo.test_case


@dataclasses.dataclass
class MinifierTestResult:
minifier_code: str
repro_code: str

def _get_module(self, t):
r = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t).group(
0
)
r = re.sub(r"\s+$", "\n", r, flags=re.MULTILINE)
r = re.sub(r"\n{3,}", "\n\n", r)
return r.strip()

def minifier_module(self):
return self._get_module(self.minifier_code)

def repro_module(self):
return self._get_module(self.repro_code)


class MinifierTestBase(torch._dynamo.test_case.TestCase):
DEBUG_DIR = tempfile.mkdtemp()

@@ -95,6 +117,9 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
log.removeHandler(log_handler)
if cwd is not None:
os.chdir(prev_cwd)
# Make sure we don't leave buggy compiled frames lying
# around
torch._dynamo.reset()
finally:
object.__setattr__(torch._dynamo.config, "_config", dynamo_config)
object.__setattr__(torch._inductor.config, "_config", inductor_config)

@@ -162,11 +187,12 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
# `run_code` is the code to run for the test case.
# `patch_code` is the code to be patched in every generated file; usually
# just use this to turn on bugs via the config
def _gen_test_code(self, run_code, repro_after, repro_level, patch_code):
def _gen_test_code(self, run_code, repro_after, repro_level):
return f"""\
import torch
import torch._dynamo
{patch_code}
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
torch._dynamo.config.repro_after = "{repro_after}"
torch._dynamo.config.repro_level = {repro_level}
torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"

@@ -175,31 +201,31 @@ torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"

# Runs a full minifier test.
# Minifier tests generally consist of 3 stages:
# 1. Run the problematic code (in a separate process since it could segfault)
# 1. Run the problematic code
# 2. Run the generated minifier launcher script
# 3. Run the generated repro script
#
# If possible, you should run the test with isolate=False; use
# isolate=True only if the bug you're testing would otherwise
# crash the process
def _run_full_test(
self, run_code, repro_after, repro_level, patch_code, *, isolate
):
test_code = self._gen_test_code(run_code, repro_after, repro_level, patch_code)
def _run_full_test(self, run_code, repro_after, expected_error, *, isolate):
if isolate:
repro_level = 3
else:
repro_level = 4 if expected_error == "AccuracyError" else 2
test_code = self._gen_test_code(run_code, repro_after, repro_level)
print("running test", file=sys.stderr)
test_proc, repro_dir = self._run_test_code(test_code, isolate=isolate)
# NB: Intentionally do not test return code; we only care about
# actually generating the repro, we don't have to crash
self.assertIn(expected_error, test_proc.stderr.decode("utf-8"))
self.assertIsNotNone(repro_dir)
print("running minifier")
launch_proc, launch_code = self._run_minifier_launcher(
print("running minifier", file=sys.stderr)
minifier_proc, minifier_code = self._run_minifier_launcher(
repro_dir, isolate=isolate
)
print("running repro")
print("running repro", file=sys.stderr)
repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate)
return (test_proc, launch_proc, repro_proc), (launch_code, repro_code)

def _run_full_test_nocode(
self, run_code, repro_after, repro_level, patch_code, *, isolate
):
tbs, _ = self._run_full_test(
run_code, repro_after, repro_level, patch_code, isolate=isolate
)
return tbs
self.assertIn(expected_error, repro_proc.stderr.decode("utf-8"))
self.assertNotEqual(repro_proc.returncode, 0)
return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code)
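For orientation, here is a hedged sketch of what a minifier test looks like when written against the refactored base class above: one run_code string, one expected error, and a MinifierTestResult for golden-output checks. The three stages described in the _run_full_test comment (run the failing program, run the generated minifier launcher, run the repro) all happen inside the helper. The class name, test name, and program body below are illustrative, not part of the PR.

import torch
import torch._dynamo
from torch._dynamo.test_minifier_common import MinifierTestBase


class ExampleMinifierTests(MinifierTestBase):  # hypothetical test class, for illustration
    def test_compile_error_is_minified(self):
        # Program compiled with the error-injecting test backend used in the diff above.
        run_code = """\
@torch._dynamo.optimize("relu_compile_error_TESTING_ONLY")
def inner(x):
    x = torch.sin(x)
    x = torch.relu(x)
    return torch.cos(x)

inner(torch.randn(20, 20))
"""
        # Runs the test program, the minifier launcher, and the repro script,
        # asserting that "ReluCompileError" appears in both the test and the repro.
        res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)
        # res.minifier_module() / res.repro_module() return the extracted
        # "class Repro(torch.nn.Module)" source, e.g. for assertExpectedInline checks.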