pytorch/test/dynamo/test_package.py

# Owner(s): ["module: dynamo"]
import importlib
import os
import sys
import tempfile
import unittest
import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._functorch import config as functorch_config
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import (
HAS_CUDA_AND_TRITON,
HAS_XPU_AND_TRITON,
)


# Module-level helper used by test_call_function_from_resume below.
def compute_loss_helper(x):
    return reduce_to_scalar_loss(x)


@functorch_config.patch("bundled_autograd_cache", True)
@torch._dynamo.config.patch({"strict_precompile": True})
@instantiate_parametrized_tests
class TestPackage(torch._inductor.test_case.TestCase):
def path(self):
path = os.path.join(cache_dir(), f"package_{self.id()}")
os.makedirs(path, exist_ok=True)
return path
def setUp(self):
super().setUp()
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
DynamoCache.clear()
PrecompileContext.clear()
def _save_and_reload(self, expected_backends, expected_dynamo):
"""
Serializes all artifacts, clears all caches, then reloads the serialized artifact
Simulates a new process.
Args:
expected_backends: Expected number of precompile_aot_autograd_artifacts
expected_dynamo: Expected number of precompile_dynamo_artifacts
"""
debug_info = PrecompileContext.save_to_dynamo_cache()
self.assertEqual(len(debug_info["dynamo"]), expected_dynamo)
self.assertEqual(len(debug_info["backends"]), expected_backends)
torch._dynamo.reset()
PrecompileContext.clear()
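
    # An nn.Module forward compiled through an explicit CompilePackage; currently
    # expected to fail because the FUNCTION_MATCH guard emitted for the module is
    # not serializable.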
@unittest.expectedFailure # FUNCTION_MATCH guard not serializable today
def test_nn_module(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10, device="cuda")
def forward(self, x):
return self.linear(x)
fn = MyModule()
package = CompilePackage(fn.forward)
compiled_fn = torch._dynamo.optimize("inductor", package=package)(fn)
x = torch.randn(10, 10, device="cuda")
compiled_fn(x)
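
    # Round trip through DiskDynamoStore: compile and save a trivial function
    # (recording eager backends when needed), reset dynamo, then reload and
    # install the package and run it under the "fail_on_recompile" stance.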
@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_basic_fn(self, backend, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
ctx = DiskDynamoStore()
def fn(x):
return x + 1
args = (
torch.randn(
3,
2,
device=device,
),
)
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
expected = compiled_fn(*args)
if backend == "eager":
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(backends)
self.assertEqual(expected, compiled_fn(*args))
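
    # Same round trip as test_basic_fn, but the input requires grad and backward
    # runs before saving, covering the lazily compiled backward path.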
@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_lazy_backward(self, backend, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
ctx = DiskDynamoStore()
def fn(x):
return x.sin() + x.cos()
args = (
torch.zeros(
3,
2,
device=device,
requires_grad=True,
),
)
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
expected = compiled_fn(*args)
expected.sum().backward()
if backend == "eager":
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(backends)
self.assertEqual(expected, compiled_fn(*args))
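
    # A recursive binary-search-style function specializes into many graphs; all
    # of them must be restored from the saved package, and an input outside the
    # original range must still raise under "fail_on_recompile".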
@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_graph_break_bomb(self, backend, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
ctx = DiskDynamoStore()
def fn(x, l, r):
if l > r:
return x.sum()
mid = (l + r) // 2
if x.sum() == mid:
return x.sum()
elif x.sum() < mid:
return fn(x, l, mid)
else:
return fn(x, mid + 1, r)
def guard_filter_fn(guards):
return [
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
for guard in guards
]
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(
backend=backend, package=package, guard_filter_fn=guard_filter_fn
)(fn)
N = 10
args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
for args in args_list:
compiled_fn(*args)
if backend == "eager":
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
for args in args_list:
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(
backend="eager", package=package, guard_filter_fn=guard_filter_fn
)(fn)
package.install(backends)
for args in args_list:
self.assertEqual(compiled_fn(*args), args[0].sum())
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(torch.tensor(N), 0, N - 1)
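
    # Marks the first input dimension dynamic (min=3, max=5) before saving; after
    # reload, an in-range shape runs without recompiling while an out-of-range
    # shape raises under "fail_on_recompile".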
@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_dynamic_shape(self, backend, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
ctx = DiskDynamoStore()
def fn(x):
return x + x.shape[0]
args = (torch.randn(3, 2, device=device),)
args1 = (torch.randn(5, 2, device=device),)
args2 = (torch.randn(7, 2, device=device),)
expected1 = fn(*args1)
torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5)
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(backend=backend, package=package)(fn)
compiled_fn(*args)
if backend == "eager":
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args1)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(backends)
self.assertEqual(expected1, compiled_fn(*args1))
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args2)
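
    # Traces into a helper module loaded from a temporary file; loading the saved
    # package must fail with "Source code changes detected" after the module's
    # source changes, and succeed again once the original source is restored.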
def test_file_change(self):
ctx = DiskDynamoStore()
def import_from_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
        mock_module_add_original = """
def add(x, y):
    return x + y
"""
        mock_module_add_modified = """
def add(x, y):
    return x - y
"""
with tempfile.TemporaryDirectory() as tmp_dir:
mock_module_add_original_path = os.path.join(
tmp_dir, "mock_module_add_original.py"
)
mock_module_add_modified_path = os.path.join(
tmp_dir, "mock_module_add_modified.py"
)
with open(mock_module_add_original_path, "w") as f:
f.write(mock_module_add_original)
with open(mock_module_add_modified_path, "w") as f:
f.write(mock_module_add_modified)
module = import_from_path(
"torch.test_package_helper",
mock_module_add_original_path,
)
def fn(x):
return module.add(x, 1)
args = (torch.randn(3, 2),)
def guard_filter_fn(guards):
return [
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
for guard in guards
]
# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(
backend="eager", package=package, guard_filter_fn=guard_filter_fn
)(fn)
compiled_fn(*args)
for backend_id, backend in package.cached_backends.items():
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())
module = import_from_path(
"torch.test_package_helper",
mock_module_add_modified_path,
)
with self.assertRaisesRegex(RuntimeError, "Source code changes detected"):
ctx.load_package(fn, self.path())
module = import_from_path(
"torch.test_package_helper",
mock_module_add_original_path,
)
ctx.load_package(fn, self.path())
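
    # Explicitly saves two packages with DynamoCache.save, simulates a new
    # process, then loads them back with load_and_install_package and verifies
    # no frames are recompiled.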
@parametrize("device", ("cpu", "cuda", "xpu"))
def test_dynamo_cache_manual_load(self, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
def fn2(x):
return x.cos() + x
package1 = CompilePackage(fn)
package2 = CompilePackage(fn2)
compiled_fn1 = torch._dynamo.optimize(backend="inductor", package=package1)(fn)
compiled_fn2 = torch._dynamo.optimize(backend="inductor", package=package2)(fn2)
arg1 = torch.randn(3, 2, device=device)
arg2 = torch.randn(5, 2, device=device)
expected = [compiled_fn1(arg1), compiled_fn2(arg2)]
DynamoCache.save(package1)
DynamoCache.save(package2)
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=2, expected_dynamo=2)
        # These should exist on disk because they were saved to DynamoCache above
package1 = DynamoCache.load_and_install_package(fn)
package2 = DynamoCache.load_and_install_package(fn2)
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn1(arg1)
result2 = compiled_fn2(arg2)
self.assertEqual(expected, [result1, result2])
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
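
    # With caching_precompile enabled, compiled artifacts are saved automatically;
    # after a simulated restart, new torch.compile wrappers should hit the cache
    # and avoid recompiling.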
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_serialize(self, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
def fn2(x):
return x.cos() + x
arg1 = torch.randn(3, 2, device=device)
arg2 = torch.randn(5, 2, device=device)
expected = [fn(arg1), fn2(arg2)]
compiled_fn1 = torch.compile(fn)
compiled_fn2 = torch.compile(fn2)
result = [compiled_fn1(arg1), compiled_fn2(arg2)]
self.assertEqual(expected, result)
DynamoCache.clear()
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=2, expected_dynamo=2)
compiled_fn1 = torch.compile(fn)
compiled_fn2 = torch.compile(fn2)
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn1(arg1)
result2 = compiled_fn2(arg2)
self.assertEqual(expected, [result1, result2])
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
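
    # The second call recompiles for a new shape (two backends, one dynamo
    # artifact); after a simulated restart the cached dynamic graph should also
    # cover a third, unseen shape without recompiling.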
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_recompiles(self, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
arg1 = torch.randn(3, 2, device=device)
arg2 = torch.randn(5, 2, device=device)
compiled_fn = torch.compile(fn)
expected1 = compiled_fn(arg1)
# Should cause a recompile
expected2 = compiled_fn(arg2)
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=2, expected_dynamo=1)
compiled_fn = torch.compile(fn)
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn(arg1)
result2 = compiled_fn(arg2)
# Because of automatic dynamic, a third random shape should also not cause a recompile
arg3 = torch.randn(7, 2, device=device)
compiled_fn(arg3)
self.assertEqual(result1, expected1)
self.assertEqual(result2, expected2)
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
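
    # The recursive function from test_graph_break_bomb, run through the automatic
    # caching path; all compiled frames should be served from the cache after a
    # simulated restart (FRAME_COUNTER stays at its cold-start value).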
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_graph_breaks(self, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x, l, r):
if l > r:
return x.sum()
mid = (l + r) // 2
if x.sum() == mid:
return x.sum()
elif x.sum() < mid:
return fn(x, l, mid)
else:
return fn(x, mid + 1, r)
def guard_filter_fn(guards):
return [
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
for guard in guards
]
# Saving
compiled_fn = torch._dynamo.optimize(
backend="inductor", guard_filter_fn=guard_filter_fn
)(fn)
N = 10
args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
for args in args_list:
compiled_fn(*args)
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=8, expected_dynamo=1)
compiled_fn = torch._dynamo.optimize(
backend="inductor", guard_filter_fn=guard_filter_fn
)(fn)
with torch.compiler.set_stance("fail_on_recompile"):
for args in args_list:
self.assertEqual(compiled_fn(*args), args[0].sum())
# Should have same number of frames as on cold start
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
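
    # Backward runs before saving; after a simulated restart, forward and backward
    # both execute under "fail_on_recompile" with no new frames compiled.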
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_lazy_backward(self, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
return x.sin() + x.cos()
arg1 = torch.randn(3, 2, device=device, requires_grad=True)
arg2 = arg1.clone().detach_().requires_grad_(True)
compiled_fn = torch.compile(fn)
expected1 = compiled_fn(arg1)
expected1.sum().backward()
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=1, expected_dynamo=1)
compiled_fn = torch.compile(fn)
# Run it again, no recompile needed
with torch.compiler.set_stance("fail_on_recompile"):
expected2 = compiled_fn(arg2)
expected2.sum().backward()
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
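
    # Deletes the cached backend of a resume function (and its entries in the
    # module's globals) before saving; the reloaded, incomplete package still
    # works, at the cost of exactly one recompiled frame.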
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_graph_break_partial_backend(self, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x):
y = x.sin()
torch._dynamo.graph_break()
return x.sin() + y
arg1 = torch.randn(3, 2, device=device, requires_grad=True)
arg2 = arg1.clone().detach_().requires_grad_(True)
compiled_fn = torch.compile(fn)
expected1 = compiled_fn(arg1)
expected1.sum().backward()
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
# Remove backends related to resume functions
dynamo_entry = next(iter(PrecompileContext._dynamo_cache_entries.values()))
for code in dynamo_entry.codes:
module = sys.modules[code.python_module]
if code.install_to_global:
# Clear the fn_names from global scope, to simulate a new environment
for fn_name in code.function_names:
module.__dict__.pop(fn_name)
for fn_name in code.function_names:
if "resume" in fn_name:
self.assertEqual(len(code.backend_ids), 1)
                    backend = code.backend_ids[0]
                    # Delete the backend associated with the resume function to
                    # simulate an environment where it was never cached
del PrecompileContext._backend_artifacts_by_key[backend]
self._save_and_reload(expected_backends=1, expected_dynamo=1)
compiled_fn = torch.compile(fn)
# Run it again. There will be a recompile because one of the backends is deleted, but it should
# still work.
expected2 = compiled_fn(arg2)
expected2.sum().backward()
self.assertEqual(expected1, expected2)
# One recompile on a new frame, so total_frames should increase by 1
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames + 1)
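
    # Compiles a function that calls the module-level compute_loss_helper and
    # invokes .backward() inside the compiled region; after a simulated restart
    # the saved package replays without any recompiles.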
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_call_function_from_resume(self, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
mod = torch.nn.Linear(2, 3, device=device)
def foo(x, mod):
pred = mod(x)
compute_loss_helper(pred).backward()
return None
args = (torch.randn(3, 2, device=device), mod)
compiled_fn = torch.compile(foo)
compiled_fn(*args)
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(expected_backends=1, expected_dynamo=1)
compiled_fn = torch.compile(foo)
# Run it again, no recompile needed
with torch.compiler.set_stance("fail_on_recompile"):
compiled_fn(*args)
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
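
    # The compiled function contains a generator expression (inside all(...));
    # checks that the resulting code objects serialize through the precompile
    # path.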
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_code_with_generator(self, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def foo(set_of_x):
if not all(isinstance(s, torch.Tensor) for s in set_of_x):
raise TypeError(
f"Expected all elements of set_of_x to be tensors, got {set_of_x}"
)
return torch.cat(set_of_x, dim=0)
args = ([torch.randn(3, 2, device=device) for _ in range(3)],)
compiled_fn = torch.compile(foo)
compiled_fn(*args)
        self._save_and_reload(expected_backends=1, expected_dynamo=1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()