# Owner(s): ["module: dynamo"]

import importlib
import os
import sys
import tempfile
import unittest

import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._functorch import config as functorch_config
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfRocm,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import (
    HAS_CUDA_AND_TRITON,
    HAS_XPU_AND_TRITON,
)


def compute_loss_helper(x):
    return reduce_to_scalar_loss(x)


@functorch_config.patch("bundled_autograd_cache", True)
@torch._dynamo.config.patch({"strict_precompile": True})
@instantiate_parametrized_tests
class TestPackage(torch._inductor.test_case.TestCase):
    def path(self):
        path = os.path.join(cache_dir(), f"package_{self.id()}")
        os.makedirs(path, exist_ok=True)
        return path

    def setUp(self):
        super().setUp()
        torch._dynamo.reset()
        torch._dynamo.utils.counters.clear()
        DynamoCache.clear()
        PrecompileContext.clear()

    def _save_and_reload(
        self, expected_backends, expected_dynamo, expected_autotune=None
    ):
        """
        Serializes all artifacts, clears all caches, then reloads the serialized
        artifacts, simulating a fresh process.

        Args:
            expected_backends: Expected number of precompile_aot_autograd_artifacts
            expected_dynamo: Expected number of precompile_dynamo_artifacts
            expected_autotune: Expected number of autotune_artifacts, if given
        """
        serialized = PrecompileContext.serialize()
        assert serialized is not None
        (bytes_, cache_info) = serialized
        self.assertEqual(
            len(cache_info.precompile_aot_autograd_artifacts), expected_backends
        )
        self.assertEqual(len(cache_info.precompile_dynamo_artifacts), expected_dynamo)
        if expected_autotune is not None:
            self.assertEqual(len(cache_info.autotune_artifacts), expected_autotune)

        torch._dynamo.reset()
        DynamoCache.clear()
        PrecompileContext.clear()

        deserialized = PrecompileContext.deserialize(bytes_)
        assert deserialized is not None
        PrecompileContext.populate_caches(deserialized)

    @unittest.expectedFailure  # FUNCTION_MATCH guard not serializable today
    def test_nn_module(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10, device="cuda")

            def forward(self, x):
                return self.linear(x)

        fn = MyModule()
        package = CompilePackage(fn.forward)
        compiled_fn = torch._dynamo.optimize("inductor", package=package)(fn)
        x = torch.randn(10, 10, device="cuda")
        compiled_fn(x)

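    # Round-trips a simple function through CompilePackage + DiskDynamoStore: compile,
    # save to disk, reset dynamo, then reload and run under "fail_on_recompile".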
    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_basic_fn(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x):
            return x + 1

        args = (
            torch.randn(
                3,
                2,
                device=device,
            ),
        )

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
        expected = compiled_fn(*args)
        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)

        ctx.save_package(package, self.path())
        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args)

            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(backends)
            self.assertEqual(expected, compiled_fn(*args))

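    # Same round trip as test_basic_fn, but the input requires grad, so the lazily
    # compiled backward must also work from the reloaded package.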
    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_lazy_backward(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x):
            return x.sin() + x.cos()

        args = (
            torch.zeros(
                3,
                2,
                device=device,
                requires_grad=True,
            ),
        )

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
        expected = compiled_fn(*args)
        expected.sum().backward()

        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)

        ctx.save_package(package, self.path())
        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args)

            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(backends)
            self.assertEqual(expected, compiled_fn(*args))

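    # The recursive binary search below branches on tensor values, forcing many graph
    # breaks and recompiles; every resulting entry must survive the save/load cycle.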
    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_graph_break_bomb(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x, l, r):
            if l > r:
                return x.sum()
            mid = (l + r) // 2
            if x.sum() == mid:
                return x.sum()
            elif x.sum() < mid:
                return fn(x, l, mid)
            else:
                return fn(x, mid + 1, r)

        def guard_filter_fn(guards):
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(
            backend=backend, package=package, guard_filter_fn=guard_filter_fn
        )(fn)
        N = 10
        args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
        for args in args_list:
            compiled_fn(*args)
        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)
        ctx.save_package(package, self.path())

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            for args in args_list:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Detected recompile when torch.compile stance is 'fail_on_recompile'",
                ):
                    compiled_fn(*args)
            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(
                backend="eager", package=package, guard_filter_fn=guard_filter_fn
            )(fn)
            package.install(backends)
            for args in args_list:
                self.assertEqual(compiled_fn(*args), args[0].sum())

            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(torch.tensor(N), 0, N - 1)

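    # Marks dim 0 of the input as dynamic with bounds [3, 5]; the reloaded guards must
    # accept an in-range shape without recompiling and reject an out-of-range one.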
    @parametrize("backend", ("eager", "inductor"))
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_dynamic_shape(self, backend, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        ctx = DiskDynamoStore()

        def fn(x):
            return x + x.shape[0]

        args = (torch.randn(3, 2, device=device),)
        args1 = (torch.randn(5, 2, device=device),)
        args2 = (torch.randn(7, 2, device=device),)
        expected1 = fn(*args1)

        torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5)

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend=backend, package=package)(fn)
        compiled_fn(*args)
        if backend == "eager":
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)
        ctx.save_package(package, self.path())

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args1)

            package, backends = ctx.load_package(fn, self.path())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(backends)

            self.assertEqual(expected1, compiled_fn(*args1))

            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args2)

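    # The package records the source of the helper module it traced through; loading
    # must fail with "Source code changes detected" once that source changes.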
    def test_file_change(self):
        ctx = DiskDynamoStore()

        def import_from_path(module_name, file_path):
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            return module

        mock_module_add_original = """
def add(x, y):
    return x + y
"""

        mock_module_add_modified = """
def add(x, y):
    return x - y
"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            mock_module_add_original_path = os.path.join(
                tmp_dir, "mock_module_add_original.py"
            )
            mock_module_add_modified_path = os.path.join(
                tmp_dir, "mock_module_add_modified.py"
            )
            with open(mock_module_add_original_path, "w") as f:
                f.write(mock_module_add_original)
            with open(mock_module_add_modified_path, "w") as f:
                f.write(mock_module_add_modified)

            module = import_from_path(
                "torch.test_package_helper",
                mock_module_add_original_path,
            )

            def fn(x):
                return module.add(x, 1)

            args = (torch.randn(3, 2),)

            def guard_filter_fn(guards):
                return [
                    guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                    for guard in guards
                ]

            # Saving
            package = CompilePackage(fn)
            compiled_fn = torch._dynamo.optimize(
                backend="eager", package=package, guard_filter_fn=guard_filter_fn
            )(fn)
            compiled_fn(*args)
            for backend_id, backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, backend)
            ctx.save_package(package, self.path())

            module = import_from_path(
                "torch.test_package_helper",
                mock_module_add_modified_path,
            )
            with self.assertRaisesRegex(RuntimeError, "Source code changes detected"):
                ctx.load_package(fn, self.path())

            module = import_from_path(
                "torch.test_package_helper",
                mock_module_add_original_path,
            )
            ctx.load_package(fn, self.path())

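    # Drives DynamoCache.save / load_and_install_package directly, together with
    # PrecompileContext serialization, instead of going through DiskDynamoStore.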
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_dynamo_cache_manual_load(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        def fn2(x):
            return x.cos() + x

        package1 = CompilePackage(fn)
        package2 = CompilePackage(fn2)
        compiled_fn1 = torch._dynamo.optimize(backend="inductor", package=package1)(fn)
        compiled_fn2 = torch._dynamo.optimize(backend="inductor", package=package2)(fn2)
        arg1 = torch.randn(3, 2, device=device)
        arg2 = torch.randn(5, 2, device=device)
        expected = [compiled_fn1(arg1), compiled_fn2(arg2)]

        DynamoCache.save(package1)
        DynamoCache.save(package2)
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
        self._save_and_reload(expected_backends=2, expected_dynamo=2)

        # These should exist because of populate_caches
        package1 = DynamoCache.load_and_install_package(fn)
        package2 = DynamoCache.load_and_install_package(fn2)

        with torch.compiler.set_stance("fail_on_recompile"):
            result1 = compiled_fn1(arg1)
            result2 = compiled_fn2(arg2)
        self.assertEqual(expected, [result1, result2])
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

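    # With caching_precompile enabled, plain torch.compile calls should populate the
    # precompile caches automatically and survive a save/reload round trip.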
    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_serialize(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        def fn2(x):
            return x.cos() + x

        arg1 = torch.randn(3, 2, device=device)
        arg2 = torch.randn(5, 2, device=device)
        expected = [fn(arg1), fn2(arg2)]
        compiled_fn1 = torch.compile(fn)
        compiled_fn2 = torch.compile(fn2)
        result = [compiled_fn1(arg1), compiled_fn2(arg2)]
        self.assertEqual(expected, result)
        DynamoCache.clear()
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=2, expected_dynamo=2)

        compiled_fn1 = torch.compile(fn)
        compiled_fn2 = torch.compile(fn2)
        with torch.compiler.set_stance("fail_on_recompile"):
            result1 = compiled_fn1(arg1)
            result2 = compiled_fn2(arg2)
        self.assertEqual(expected, [result1, result2])
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

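    # Max-autotune results should be captured as autotune artifacts and hit the local
    # autotune cache after reload (see the Stats assertions below).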
    @parametrize("device", ("cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    @skipIfXpu
    @skipIfRocm
    def test_automatic_dynamo_autotune_cache(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x, y):
            return x.sin() + y

        arg1 = torch.randn(3, 3, device=device)
        arg2 = torch.randn(3, 3, device=device)
        expected = fn(arg1, arg2).clone()

        with PatchCaches():
            compiled_fn1 = torch.compile(fn, mode="max-autotune")
            result = compiled_fn1(arg1, arg2).clone()
            self.assertEqual(expected, result)
            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))
            DynamoCache.clear()

            total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
            self._save_and_reload(
                expected_backends=1, expected_dynamo=1, expected_autotune=1
            )
            compiled_fn1 = torch.compile(fn, mode="max-autotune")
            with torch.compiler.set_stance("fail_on_recompile"):
                result1 = compiled_fn1(arg1, arg2).clone()
            self.assertEqual(expected, result1)
            self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))

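    # A second input shape triggers a recompile before saving; after reload, the
    # recorded dynamic compilation should cover new shapes without recompiling.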
    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_recompiles(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        arg1 = torch.randn(3, 2, device=device)
        arg2 = torch.randn(5, 2, device=device)
        compiled_fn = torch.compile(fn)
        expected1 = compiled_fn(arg1)

        # Should cause a recompile
        expected2 = compiled_fn(arg2)
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=2, expected_dynamo=1)

        compiled_fn = torch.compile(fn)
        with torch.compiler.set_stance("fail_on_recompile"):
            result1 = compiled_fn(arg1)
            result2 = compiled_fn(arg2)
            # Because of automatic dynamic, a third random shape should also not cause a recompile
            arg3 = torch.randn(7, 2, device=device)
            compiled_fn(arg3)
        self.assertEqual(result1, expected1)
        self.assertEqual(result2, expected2)
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

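    # Same graph-break-heavy recursion as test_graph_break_bomb, but driven through
    # caching_precompile; the warm run must replay every frame from the cache.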
    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_graph_breaks(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x, l, r):
            if l > r:
                return x.sum()
            mid = (l + r) // 2
            if x.sum() == mid:
                return x.sum()
            elif x.sum() < mid:
                return fn(x, l, mid)
            else:
                return fn(x, mid + 1, r)

        def guard_filter_fn(guards):
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        compiled_fn = torch._dynamo.optimize(
            backend="inductor", guard_filter_fn=guard_filter_fn
        )(fn)
        N = 10
        args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
        for args in args_list:
            compiled_fn(*args)

        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
        self._save_and_reload(expected_backends=8, expected_dynamo=1)

        compiled_fn = torch._dynamo.optimize(
            backend="inductor", guard_filter_fn=guard_filter_fn
        )(fn)
        with torch.compiler.set_stance("fail_on_recompile"):
            for args in args_list:
                self.assertEqual(compiled_fn(*args), args[0].sum())
        # Should have same number of frames as on cold start
        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

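    # Backward is compiled lazily on the first .backward() call; the cached package
    # must serve both forward and backward without recompiling.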
    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_automatic_dynamo_lazy_backward(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def fn(x):
            return x.sin() + x.cos()

        arg1 = torch.randn(3, 2, device=device, requires_grad=True)
        arg2 = arg1.clone().detach_().requires_grad_(True)

        compiled_fn = torch.compile(fn)
        expected1 = compiled_fn(arg1)
        expected1.sum().backward()
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=1, expected_dynamo=1)

        compiled_fn = torch.compile(fn)
        # Run it again, no recompile needed
        with torch.compiler.set_stance("fail_on_recompile"):
            expected2 = compiled_fn(arg2)
            expected2.sum().backward()

        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

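    # The .backward() call inside foo causes a graph break, so compute_loss_helper is
    # reached via a resume function; the cached package must cover that frame as well.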
    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_call_function_from_resume(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")
        mod = torch.nn.Linear(2, 3, device=device)

        def foo(x, mod):
            pred = mod(x)
            compute_loss_helper(pred).backward()
            return None

        args = (torch.randn(3, 2, device=device), mod)
        compiled_fn = torch.compile(foo)
        compiled_fn(*args)
        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER

        self._save_and_reload(expected_backends=1, expected_dynamo=1)

        compiled_fn = torch.compile(foo)
        # Run it again, no recompile needed
        with torch.compiler.set_stance("fail_on_recompile"):
            compiled_fn(*args)

        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)

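    # foo contains a generator expression, so its inner code object's co_firstlineno
    # differs from what inspect.getsource reports for the enclosing function; packaging
    # keys off the enclosing source location, so this must round-trip cleanly.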
    @parametrize("device", ("cpu", "cuda", "xpu"))
    @torch._dynamo.config.patch(caching_precompile=True)
    def test_code_with_generator(self, device):
        if device == "cuda" and not HAS_CUDA_AND_TRITON:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU_AND_TRITON:
            raise unittest.SkipTest("Requires XPU/Triton")

        def foo(set_of_x):
            if not all(isinstance(s, torch.Tensor) for s in set_of_x):
                raise TypeError(
                    f"Expected all elements of set_of_x to be tensors, got {set_of_x}"
                )

            return torch.cat(set_of_x, dim=0)

        args = ([torch.randn(3, 2, device=device) for _ in range(3)],)
        compiled_fn = torch.compile(foo)
        compiled_fn(*args)
        self._save_and_reload(expected_backends=1, expected_dynamo=1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()