# Owner(s): ["module: dynamo"]

import importlib
import os
import sys
import tempfile
import unittest

import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._functorch import config as functorch_config
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import (
    HAS_CUDA_AND_TRITON,
    HAS_XPU_AND_TRITON,
)


def compute_loss_helper(x):
    return reduce_to_scalar_loss(x)


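# The tests below exercise precompile packaging end to end: functions are compiled with a
# CompilePackage attached, the resulting dynamo and backend artifacts are serialized (via
# DiskDynamoStore or DynamoCache), in-memory state is cleared to simulate a fresh process,
# and the reloaded package is expected to run without triggering recompilation.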
@functorch_config.patch("bundled_autograd_cache", True)
@torch._dynamo.config.patch({"strict_precompile": True})
@instantiate_parametrized_tests
class TestPackage(torch._inductor.test_case.TestCase):
    def path(self):
        path = os.path.join(cache_dir(), f"package_{self.id()}")
        os.makedirs(path, exist_ok=True)
        return path

    def setUp(self):
        super().setUp()
        torch._dynamo.reset()
        torch._dynamo.utils.counters.clear()
        DynamoCache.clear()
        PrecompileContext.clear()

    def _save_and_reload(self, expected_backends, expected_dynamo):
        """
        Serializes all artifacts, clears all caches, then reloads the serialized
        artifacts, simulating a new process.

        Args:
            expected_backends: Expected number of precompile_aot_autograd_artifacts
            expected_dynamo: Expected number of precompile_dynamo_artifacts
        """
        debug_info = PrecompileContext.save_to_dynamo_cache()
        self.assertEqual(len(debug_info["dynamo"]), expected_dynamo)
        self.assertEqual(len(debug_info["backends"]), expected_backends)
        torch._dynamo.reset()
        PrecompileContext.clear()

    @unittest.expectedFailure  # FUNCTION_MATCH guard not serializable today
    def test_nn_module(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10, device="cuda")

            def forward(self, x):
                return self.linear(x)

        fn = MyModule()
        package = CompilePackage(fn.forward)
        compiled_fn = torch._dynamo.optimize("inductor", package=package)(fn)
        x = torch.randn(10, 10, device="cuda")
        compiled_fn(x)

    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_basic_fn(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x):
            return x + 1

        args = (
            torch.randn(
                3,
                2,
                device=device,
            ),
        )

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
        expected = compiled_fn(*args)
        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)

        ctx.save_package(package, self.path())
        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args)

            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(backends)
            self.assertEqual(expected, compiled_fn(*args))

    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_lazy_backward(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x):
            return x.sin() + x.cos()

        args = (
            torch.zeros(
                3,
                2,
                device=device,
                requires_grad=True,
            ),
        )

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
        expected = compiled_fn(*args)
        expected.sum().backward()

        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)

        ctx.save_package(package, self.path())
        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args)

            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(backends)
            self.assertEqual(expected, compiled_fn(*args))

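    # The recursive binary search in fn below branches on tensor values, so Dynamo
    # graph-breaks repeatedly and compiles many frames; the saved package has to
    # cover all of them.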
    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_graph_break_bomb(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x, l, r):
            if l > r:
                return x.sum()
            mid = (l + r) // 2
            if x.sum() == mid:
                return x.sum()
            elif x.sum() < mid:
                return fn(x, l, mid)
            else:
                return fn(x, mid + 1, r)

        def guard_filter_fn(guards):
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(
            backend=backend, package=package, guard_filter_fn=guard_filter_fn
        )(fn)
        N = 10
        args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
        for args in args_list:
            compiled_fn(*args)
        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)
        ctx.save_package(package, self.path())

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            for args in args_list:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Detected recompile when torch.compile stance is 'fail_on_recompile'",
                ):
                    compiled_fn(*args)
            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(
                backend="eager", package=package, guard_filter_fn=guard_filter_fn
            )(fn)
            package.install(backends)
            for args in args_list:
                self.assertEqual(compiled_fn(*args), args[0].sum())

            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(torch.tensor(N), 0, N - 1)

    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_dynamic_shape(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x):
            return x + x.shape[0]

        args = (torch.randn(3, 2, device=device),)
        args1 = (torch.randn(5, 2, device=device),)
        args2 = (torch.randn(7, 2, device=device),)
        expected1 = fn(*args1)

        torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5)

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend=backend, package=package)(fn)
        compiled_fn(*args)
        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)
        ctx.save_package(package, self.path())

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args1)

            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(backends)

            self.assertEqual(expected1, compiled_fn(*args1))

            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args2)

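    # A package records the source code of every traced function; loading must fail when
    # that source has changed on disk and succeed again once the original is restored.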
    def test_file_change(self):
        ctx = DiskDynamoStore()

        def import_from_path(module_name, file_path):
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            return module

        mock_module_add_original = """
def add(x, y):
    return x + y
"""

        mock_module_add_modified = """
def add(x, y):
    return x - y
"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            mock_module_add_original_path = os.path.join(
                tmp_dir, "mock_module_add_original.py"
            )
            mock_module_add_modified_path = os.path.join(
                tmp_dir, "mock_module_add_modified.py"
            )
            with open(mock_module_add_original_path, "w") as f:
                f.write(mock_module_add_original)
            with open(mock_module_add_modified_path, "w") as f:
                f.write(mock_module_add_modified)

            module = import_from_path(
                "torch.test_package_helper",
                mock_module_add_original_path,
            )

            def fn(x):
                return module.add(x, 1)

            args = (torch.randn(3, 2),)

            def guard_filter_fn(guards):
                return [
                    guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                    for guard in guards
                ]

            # Saving
            package = CompilePackage(fn)
            compiled_fn = torch._dynamo.optimize(
                backend="eager", package=package, guard_filter_fn=guard_filter_fn
            )(fn)
            compiled_fn(*args)
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)
            ctx.save_package(package, self.path())

            module = import_from_path(
                "torch.test_package_helper",
                mock_module_add_modified_path,
            )
            with self.assertRaisesRegex(RuntimeError, "Source code changes detected"):
                ctx.load_package(fn, self.path())

            module = import_from_path(
                "torch.test_package_helper",
                mock_module_add_original_path,
            )
            ctx.load_package(fn, self.path())

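    # Manual DynamoCache round trip: save two independent packages, clear all in-memory
    # state via _save_and_reload, then load_and_install_package each one and expect the
    # compiled functions to run without any new frames being compiled.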
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_dynamo_cache_manual_load(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        def fn2(x):
            return x.cos() + x

        package1 = CompilePackage(fn)
        package2 = CompilePackage(fn2)
        compiled_fn1 = torch._dynamo.optimize(backend="inductor", package=package1)(fn)
        compiled_fn2 = torch._dynamo.optimize(backend="inductor", package=package2)(fn2)
        arg1 = torch.randn(3, 2, device=device)
        arg2 = torch.randn(5, 2, device=device)
        expected = [compiled_fn1(arg1), compiled_fn2(arg2)]

        DynamoCache.save(package1)
        DynamoCache.save(package2)
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
        self._save_and_reload(expected_backends=2, expected_dynamo=2)

        # These should exist because of populate_caches
        package1 = DynamoCache.load_and_install_package(fn)
        package2 = DynamoCache.load_and_install_package(fn2)

        with torch.compiler.set_stance("fail_on_recompile"):
            result1 = compiled_fn1(arg1)
            result2 = compiled_fn2(arg2)
        self.assertEqual(expected, [result1, result2])
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_serialize(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        def fn2(x):
            return x.cos() + x

        arg1 = torch.randn(3, 2, device=device)
        arg2 = torch.randn(5, 2, device=device)
        expected = [fn(arg1), fn2(arg2)]
        compiled_fn1 = torch.compile(fn)
        compiled_fn2 = torch.compile(fn2)
        result = [compiled_fn1(arg1), compiled_fn2(arg2)]
        self.assertEqual(expected, result)
        DynamoCache.clear()
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=2, expected_dynamo=2)

        compiled_fn1 = torch.compile(fn)
        compiled_fn2 = torch.compile(fn2)
        with torch.compiler.set_stance("fail_on_recompile"):
            result1 = compiled_fn1(arg1)
            result2 = compiled_fn2(arg2)
        self.assertEqual(expected, [result1, result2])
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_recompiles(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        arg1 = torch.randn(3, 2, device=device)
        arg2 = torch.randn(5, 2, device=device)
        compiled_fn = torch.compile(fn)
        expected1 = compiled_fn(arg1)

        # Should cause a recompile
        expected2 = compiled_fn(arg2)
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=2, expected_dynamo=1)

        compiled_fn = torch.compile(fn)
        with torch.compiler.set_stance("fail_on_recompile"):
            result1 = compiled_fn(arg1)
            result2 = compiled_fn(arg2)
            # Because of automatic dynamic, a third shape should also not cause a recompile
            arg3 = torch.randn(7, 2, device=device)
            compiled_fn(arg3)
        self.assertEqual(result1, expected1)
        self.assertEqual(result2, expected2)
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_graph_breaks(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x, l, r):
            if l > r:
                return x.sum()
            mid = (l + r) // 2
            if x.sum() == mid:
                return x.sum()
            elif x.sum() < mid:
                return fn(x, l, mid)
            else:
                return fn(x, mid + 1, r)

        def guard_filter_fn(guards):
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        compiled_fn = torch._dynamo.optimize(
            backend="inductor", guard_filter_fn=guard_filter_fn
        )(fn)
        N = 10
        args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
        for args in args_list:
            compiled_fn(*args)

        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
        self._save_and_reload(expected_backends=8, expected_dynamo=1)

        compiled_fn = torch._dynamo.optimize(
            backend="inductor", guard_filter_fn=guard_filter_fn
        )(fn)
        with torch.compiler.set_stance("fail_on_recompile"):
            for args in args_list:
                self.assertEqual(compiled_fn(*args), args[0].sum())
        # Should have same number of frames as on cold start
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_lazy_backward(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        arg1 = torch.randn(3, 2, device=device, requires_grad=True)
        arg2 = arg1.clone().detach_().requires_grad_(True)

        compiled_fn = torch.compile(fn)
        expected1 = compiled_fn(arg1)
        expected1.sum().backward()
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=1, expected_dynamo=1)

        compiled_fn = torch.compile(fn)
        # Run it again, no recompile needed
        with torch.compiler.set_stance("fail_on_recompile"):
            expected2 = compiled_fn(arg2)
            expected2.sum().backward()

        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

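    # Simulate a cache where the backend artifact for a resume (post-graph-break) frame
    # is missing: loading should still work, at the cost of recompiling that one frame.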
    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_graph_break_partial_backend(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            y = x.sin()
            torch._dynamo.graph_break()
            return x.sin() + y

        arg1 = torch.randn(3, 2, device=device, requires_grad=True)
        arg2 = arg1.clone().detach_().requires_grad_(True)
        compiled_fn = torch.compile(fn)
        expected1 = compiled_fn(arg1)
        expected1.sum().backward()
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        # Remove backends related to resume functions
        dynamo_entry = next(iter(PrecompileContext._dynamo_cache_entries.values()))
        for code in dynamo_entry.codes:
            module = sys.modules[code.python_module]
            if code.install_to_global:
                # Clear the fn_names from global scope, to simulate a new environment
                for fn_name in code.function_names:
                    module.__dict__.pop(fn_name)
            for fn_name in code.function_names:
                if "resume" in fn_name:
                    self.assertEqual(len(code.backend_ids), 1)
                    # Delete the backend associated with the resume function
                    backend = code.backend_ids[0]
                    del PrecompileContext._backend_artifacts_by_key[backend]

        self._save_and_reload(expected_backends=1, expected_dynamo=1)

        compiled_fn = torch.compile(fn)
        # Run it again. There will be a recompile because one of the backends was
        # deleted, but it should still work.
        expected2 = compiled_fn(arg2)
        expected2.sum().backward()
        self.assertEqual(expected1, expected2)
        # One recompile on a new frame, so total_frames should increase by 1
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames + 1)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_call_function_from_resume(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")
        mod = torch.nn.Linear(2, 3, device=device)

        def foo(x, mod):
            pred = mod(x)
            compute_loss_helper(pred).backward()
            return None

        args = (torch.randn(3, 2, device=device), mod)
        compiled_fn = torch.compile(foo)
        compiled_fn(*args)
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=1, expected_dynamo=1)

        compiled_fn = torch.compile(foo)
        # Run it again, no recompile needed
        with torch.compiler.set_stance("fail_on_recompile"):
            compiled_fn(*args)

        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_code_with_generator(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def foo(set_of_x):
            if not all(isinstance(s, torch.Tensor) for s in set_of_x):
                raise TypeError(
                    f"Expected all elements of set_of_x to be tensors, got {set_of_x}"
                )

            return torch.cat(set_of_x, dim=0)

        args = ([torch.randn(3, 2, device=device) for _ in range(3)],)
        compiled_fn = torch.compile(foo)
        compiled_fn(*args)
        self._save_and_reload(expected_backends=1, expected_dynamo=1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()