# Owner(s): ["module: dynamo"]

import operator
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._C import (
    _len_torch_function_stack,
    _pop_torch_function_stack,
    _push_on_torch_function_stack,
)
from torch._dynamo.utils import counters
from torch.overrides import (
    _get_current_function_mode_stack,
    BaseTorchFunctionMode,
    TorchFunctionMode,
)
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu
from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode


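# Use the active accelerator's device type (e.g. "cuda" or "xpu") when one is
# available, otherwise fall back to "cpu".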
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)


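# Torch function mode used throughout these tests: it short-circuits torch.add to
# return zeros, so tests can observe whether the mode was actually consulted.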
class TestMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if not kwargs:
            kwargs = {}

        if func == torch.add:
            return torch.zeros(2, 2)

        return super().__torch_function__(func, types, args, kwargs)


class HopDetectionError(Exception):
    pass


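# Raises from __torch_function__ when the flex_attention higher-order op is
# intercepted; the HOP tests below rely on this to confirm the mode sees the op.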
class TestModeRaises(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if not kwargs:
            kwargs = {}

        import torch._higher_order_ops

        if func == torch._higher_order_ops.flex_attention:
            raise HopDetectionError("test")

        return super().__torch_function__(func, types, args, kwargs)


class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

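    # A TorchDispatchMode that opts in via ignore_compile_internals() should only
    # observe the two explicit mylib::foo calls, not the ops Inductor runs while
    # compiling and executing `g`.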
    def test_torch_dispatch_ignore_compile_internals(self):
        counters.clear()
        from torch.utils._python_dispatch import TorchDispatchMode

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            return x.clone()

        def checksum(x):
            return x.abs().sum()

        _checksums = []

        class ChecksumFoo(TorchDispatchMode):
            @classmethod
            def ignore_compile_internals(cls):
                return True

            def __init__(self) -> None:
                super().__init__()

            def __torch_dispatch__(self, func, types, args, kwargs=None):
                kwargs = kwargs or {}

                if func is torch.ops.mylib.foo.default:
                    # Do some compute, smoketest to see if there's a bad interaction
                    _checksums.append(args[0].abs().sum())

                return func(*args, **kwargs)

        # test e2e, with Inductor, as smoketest.
        @torch._dynamo.error_on_graph_break(True)
        @torch.compile(backend="inductor")
        def g(x):
            return 2 * x.sin().cos()

        x = torch.randn(3)

        with ChecksumFoo():
            foo(x)
            g(x)
            foo(x)

        self.assertEqual(len(_checksums), 2)
        # The correct result here is 1: Dynamo should capture the `g` frame.
        self.assertEqual(counters["frames"]["total"], 1)
        self.assertEqual(counters["frames"]["ok"], 1)

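    # With an ordinary TorchDispatchMode active, Dynamo is expected to skip the frame
    # entirely (frame_count stays 0) while the compiled call still matches eager.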
    def test_skip_torch_dispatch_modes(self):
        class RewriteAddToMul(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if func is torch.ops.aten.add.Tensor:
                    func = torch.ops.aten.mul.Tensor
                return func(*args, **kwargs)

        def fn(x):
            return x + x

        cnt = torch._dynamo.testing.CompileCounter()

        x = torch.tensor([3.0])
        with RewriteAddToMul():
            eager_res = fn(x)
            compiled_res = torch.compile(fn, backend=cnt)(x)

        self.assertEqual(eager_res, compiled_res)
        self.assertEqual(cnt.frame_count, 0)


class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.default_device_old = torch.get_default_device()
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        torch.set_default_device(cls.default_device_old)
        super().tearDownClass()

    def setUp(self):
        torch.set_default_device(None)
        torch._dynamo.reset()

    def tearDown(self):
        torch.set_default_device(None)
        torch._dynamo.reset()

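    # Each distinct torch function mode stack (contents and order) should trigger a
    # recompile; re-entering a previously seen stack should hit the compile cache.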
    def _run_torch_function_mode_guard_test(self):
        class TestMode1(BaseTorchFunctionMode):
            pass

        class TestMode2(BaseTorchFunctionMode):
            pass

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt.__call__)
        def fn(x):
            return x + 1

        inp = torch.ones(2, 2)
        fn(inp)
        self.assertEqual(cnt.frame_count, 1)

        with TestMode1():
            fn(inp)
        self.assertEqual(cnt.frame_count, 2)

        with TestMode1(), TestMode2():
            fn(inp)
        self.assertEqual(cnt.frame_count, 3)

        with TestMode2(), TestMode1():
            fn(inp)
        self.assertEqual(cnt.frame_count, 4)

        with TestMode1():
            fn(inp)
        self.assertEqual(cnt.frame_count, 4)

    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
    def test_torch_function_mode_guards_py(self):
        self._run_torch_function_mode_guard_test()

    def test_torch_function_mode_guards_cpp(self):
        self._run_torch_function_mode_guard_test()

    @requires_gpu
    def test_torch_function_mode_preserves_cuda_rng_state(self):
        class ConstantReturnMode(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                return -42

        @torch._dynamo.optimize("eager")
        def fn():
            with ConstantReturnMode():
                return 123

        self.assertEqual(fn(), 123)

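    # torch.set_default_device inside a compiled region should install a DeviceContext
    # at the bottom of the torch function mode stack without disturbing user modes.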
    def test_stack_state_mutation_default_device(self):
        m = BaseTorchFunctionMode()
        m1 = BaseTorchFunctionMode()
        with m, m1:

            @torch.compile(fullgraph=True)
            def fn(x):
                torch.set_default_device("cpu")
                _pop_torch_function_stack()

            fn(torch.ones(2, 2))
            _push_on_torch_function_stack(m1)

            stack = _get_current_function_mode_stack()
            self.assertIsInstance(stack[0], DeviceContext)
            self.assertEqual(stack[0].device, torch.device("cpu"))
            self.assertIs(stack[1], m)
            self.assertIs(stack[2], m1)

    def test_stack_state_clear_default_device(self):
        @torch.compile(fullgraph=True)
        def fn(x):
            torch.set_default_device(None)
            return x + 1

        fn(torch.ones(2, 2))
        stack = _get_current_function_mode_stack()
        self.assertEqual(len(stack), 0)

        m = BaseTorchFunctionMode()
        m1 = BaseTorchFunctionMode()

        # Stack populated, add device
        with m, m1:

            @torch.compile(fullgraph=True)
            def fn(x):
                torch.set_default_device("cpu")
                torch.set_default_device(None)
                torch.set_default_device("cpu")
                return x + 1

            fn(torch.ones(2, 2))
            stack = _get_current_function_mode_stack()
            self.assertEqual(stack[0].device, torch.device("cpu"))
            self.assertIs(stack[1], m)
            self.assertIs(stack[2], m1)

        # Stack populated, remove device
        torch.set_default_device("cpu")
        with m, m1:

            @torch.compile(fullgraph=True)
            def fn(x):
                torch.set_default_device(None)
                return x + 1

            fn(torch.ones(2, 2))
            stack = _get_current_function_mode_stack()
            self.assertIs(stack[0], m)
            self.assertIs(stack[1], m1)

        @torch.compile(fullgraph=True)
        def fn(x):
            torch.set_default_device("cpu")
            torch.set_default_device("cpu")
            return x + 1

        fn(torch.ones(2, 2))
        stack = _get_current_function_mode_stack()
        self.assertEqual(stack[0].device, torch.device("cpu"))
        torch.set_default_device(None)

    def test_pop_torch_function_mode(self):
        m = BaseTorchFunctionMode()
        with m:

            @torch.compile(fullgraph=True)
            def fn(x):
                _pop_torch_function_stack()
                return x + 1

            fn(torch.ones(2, 2))

            self.assertEqual(_len_torch_function_stack(), 0)
            # reset stack so __exit__ doesn't crash
            _push_on_torch_function_stack(m)

        self.assertEqual(_len_torch_function_stack(), 0)

    def test_is_torch_function_all_disabled(self):
        @torch.compile(fullgraph=True)
        def fn(x):
            return (
                torch._C._is_torch_function_all_disabled(),
                torch.add(x, 1.0),
            )

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_error_empty_stack_pop_torch_function_mode(self):
        @torch.compile(fullgraph=True)
        def fn(x):
            _pop_torch_function_stack()
            return x + 1

        self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Attempted to pop from empty torch function mode stack",
            lambda: fn(torch.ones(2, 2)),
        )

    def test_push_torch_function_mode(self):
        m = BaseTorchFunctionMode()
        with m:

            @torch.compile(fullgraph=True)
            def fn(x, m):
                _push_on_torch_function_stack(m)
                return x + 1

            fn(torch.ones(2, 2), m)

            self.assertEqual(_len_torch_function_stack(), 2)
            # reset stack state
            _pop_torch_function_stack()

        self.assertEqual(_len_torch_function_stack(), 0)

    def test_len_torch_function_mode(self):
        m = BaseTorchFunctionMode()
        with m:

            @torch.compile(fullgraph=True)
            def fn(x):
                z = _len_torch_function_stack()
                return x + z

            res = fn(torch.ones(2, 2))
            self.assertEqual(res, torch.ones(2, 2) + 1)
            self.assertEqual(_len_torch_function_stack(), 1)

    def test_intermedate_torch_function_mode_construction_mutation(self):
        class TestMode(BaseTorchFunctionMode):
            def __init__(self, x):
                self.x = x

        @torch.compile(fullgraph=True)
        def fn(x):
            z = TestMode(2)
            z.y = 2
            return x + 1, z

        fn(torch.ones(2, 2))

    def test_torch_function_mode_enabled_guard(self):
        cnt = torch._dynamo.testing.CompileCounter()
        inp = torch.ones(2, 2)

        @torch.compile(backend=cnt.__call__)
        def fn(x):
            return x + 1

        with BaseTorchFunctionMode(), torch._C.DisableTorchFunctionSubclass():
            with torch._C.DisableTorchFunction():
                fn(inp)
            fn(inp)
        self.assertEqual(cnt.frame_count, 2)

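    # Nested modes: the innermost mode (entered last) is dispatched first and the outer
    # mode runs only when the inner one delegates; eager and compiled runs should agree
    # on both the results and which modes fired.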
    def test_nested_torch_function_mode(self):
        mode_1_called = False
        mode_2_called = False

        def reset_state():
            nonlocal mode_1_called
            nonlocal mode_2_called
            mode_1_called = False
            mode_2_called = False

        ones = torch.ones(2, 2)
        zeros = torch.zeros(2, 2)

        class TestMode1(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_1_called

                mode_1_called = True

                if func == torch.add:
                    return zeros

                return super().__torch_function__(func, types, args, kwargs)

        class TestMode2(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_2_called

                mode_2_called = True

                if func == torch.mul:
                    return ones

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        def fn_2(x):
            return torch.mul(x, 3) + torch.add(x, 3)

        inp = torch.ones(2, 2) + 1

        for fn_i in [fn, fn_2]:
            fn_opt = torch.compile(fn_i, fullgraph=True)
            with TestMode1(), TestMode2():
                expected = fn_i(inp), mode_1_called, mode_2_called
                reset_state()
                actual = fn_opt(inp), mode_1_called, mode_2_called
                reset_state()

            self.assertEqual(expected, actual)

    def test_torch_function_mode_disable(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        class TestMode(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                if func == torch.add:
                    return torch.zeros(2, 2)

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode():
            with torch._C.DisableTorchFunctionSubclass():
                expected = fn(inp)
                actual = fn_opt(inp)

                self.assertEqual(expected, actual)

            with torch._C.DisableTorchFunction():
                expected = fn(inp)
                actual = fn_opt(inp)

                self.assertEqual(expected, actual)

    def test_torch_function_mode_highest_priority(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode():
            expected = fn(inp)
            actual = fn_opt(inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_enter_exit(self):
        def fn(x, y):
            with TestMode():
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn, fullgraph=True)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_graph_break(self):
        def fn(x, y):
            with TestMode():
                torch._dynamo.graph_break()
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_and_pop_graph_break(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

    def test_torch_function_mode_restore_on_exc(self):
        @torch._dynamo.disable()
        def err():
            raise RuntimeError("test")

        @torch.compile()
        def fn(x):
            with TestMode():
                x += 1
                err()
                x += 2
                return x

        try:
            fn(torch.ones(2, 2))
        except RuntimeError:
            pass
        self.assertEqual(_len_torch_function_stack(), 0)

    def test_torch_function_mode_and_pop_graph_break_mutation(self):
        def fn(x, y):
            with TestMode():
                z = _pop_torch_function_stack()
                z.y = 5
                torch._dynamo.graph_break()
                _push_on_torch_function_stack(z)
                o = torch.add(x, 3)
                o = torch.mul(o, z.y)

            return torch.add(o, y)

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
        fn_opt = torch.compile(fn)

        expected = fn(*inp)
        actual = fn_opt(*inp)

        self.assertEqual(expected, actual)

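    # Builtin operators traced by Dynamo (operator.add, operator.neg, ...) should map
    # to the torch-level functions recorded in BUILTIN_TO_TENSOR_FN_MAP (and their
    # reverse variants in BUILTIN_TO_TENSOR_RFN_MAP), as observed by a
    # TorchFunctionMode during compilation.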
    # Needs larger cache size since we recompile for each op
    @patch.object(torch._dynamo.config, "recompile_limit", 48)
    def test_builtin_equivalent_funcs(self):
        from torch._dynamo.variables.builtin import (
            BUILTIN_TO_TENSOR_FN_MAP,
            BUILTIN_TO_TENSOR_RFN_MAP,
        )
        from torch._dynamo.variables.torch_function import (
            bin_int_ops,
            bin_ops,
            tensor_and_int_ops,
            un_int_ops,
            un_ops,
        )

        expected_func = None
        valid = False

        class FuncEquivMode(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal expected_func
                nonlocal valid
                if not kwargs:
                    kwargs = {}
                if torch._dynamo.is_compiling():
                    valid = expected_func == func
                return super().__torch_function__(func, types, args, kwargs)

        inp0 = torch.ones(1, 1)
        inp1 = torch.ones(1, 1)
        inp0_int = torch.ones(1, 1, dtype=torch.int32)
        inp1_int = torch.ones(1, 1, dtype=torch.int32)

        @torch.compile(fullgraph=True)
        def fn_un(op, inp):
            return op(inp)

        @torch.compile(fullgraph=True)
        def fn_un_int(op, inp):
            return op(inp)

        @torch.compile(fullgraph=True)
        def fn_bin(op, inp0, inp1):
            return op(inp0, inp1)

        @torch.compile(fullgraph=True)
        def fn_bin_int(op, inp0, inp1):
            return op(inp0, inp1)

        @torch.compile(fullgraph=True)
        def fn_tensor_and_int(op, inp0, inp1):
            return op(inp0, inp1)

        setups_and_oplists = [
            (lambda o: fn_un(o, inp0), un_ops),
            (lambda o: fn_un_int(o, inp0_int), un_int_ops),
            (lambda o: fn_bin(o, inp0, inp1), bin_ops),
            (lambda o: fn_bin_int(o, inp0_int, inp1_int), bin_int_ops),
            (lambda o: fn_tensor_and_int(o, inp0_int, 0), tensor_and_int_ops),
        ]

        # gather the reverse functions
        rsetups_and_oplists = [
            (
                lambda o: fn_bin(o, 1, inp1),
                bin_ops,
            ),  # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int))
            (lambda o: fn_bin_int(o, 1, inp1_int), bin_int_ops),
            (lambda o: fn_tensor_and_int(o, 0, inp0_int), tensor_and_int_ops),
        ]

        skips = {operator.not_}  # Has local scalar dense call which graph breaks
        rskips = {
            operator.matmul,
            operator.imatmul,
            operator.getitem,
        }  # Doesn't type check with reversed args

        def run_checks(setups_and_oplists, skips, ref_map):
            nonlocal valid
            nonlocal expected_func
            for setup_fn, op_list in setups_and_oplists:
                for op in op_list:
                    if op in skips or op not in ref_map:
                        continue
                    with FuncEquivMode():
                        expected_func = ref_map[op]
                        setup_fn(op)
                    self.assertTrue(valid)

                    expected_func = None
                    valid = False

        run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP)
        run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP)

    def test_expand(self):
        from torch.distributions import (
            AffineTransform,
            ComposeTransform,
            Normal,
            TanhTransform,
            TransformedDistribution,
        )

        # https://github.com/pytorch/pytorch/issues/141232
        with torch.device("cpu"):

            @torch.compile(fullgraph=True)
            def func(a):
                d = TransformedDistribution(
                    Normal(a, 1),
                    ComposeTransform([TanhTransform(), AffineTransform(2, 2)]),
                )
                b = d.log_prob(d.rsample((10,)))
                return b

            func(torch.randn(3))

    @requires_gpu
    def test_flex_attention(self):
        import torch
        from torch.nn.attention.flex_attention import create_block_mask, flex_attention

        torch.set_default_device(device_type)

        flex_attention = torch.compile(flex_attention, dynamic=False)

        prefix_lengths = torch.arange(8)

        def prefix_lm(b, h, q, kv):
            return prefix_lengths[b] >= kv

        # This runs in fullgraph already
        create_block_mask(
            prefix_lm, 8, None, 512, 512, _compile=True, device=device_type
        )

    def test_register_hook(self):
        import functools

        def my_hook(grad, *, k=0):
            return grad + k

        hook = functools.partial(my_hook, k=3)

        class MyMod(torch.nn.Module):
            def forward(self, x):
                x.register_hook(hook)
                y = x.mul(2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        x = torch.ones(4, requires_grad=True)

        with torch.device("cpu"):
            torch.compile(mod, fullgraph=True)(x)

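    # The two tests below exercise a TorchFunctionMode that raises when it intercepts
    # the flex_attention higher-order op; both the compiled and eager entry points are
    # expected to surface the failure as torch._dynamo.exc.Unsupported.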
    @requires_gpu
    @skipIfXpu(msg="XPU does not support flex attention")
    def test_hop(self):
        import torch
        import torch._higher_order_ops
        from torch.nn.attention.flex_attention import (
            flex_attention as flex_attention_eager,
        )

        with torch.device(GPU_TYPE):
            flex_attention = torch.compile(flex_attention_eager, dynamic=False)

            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported,
                "raised exception HopDetectionError([ConstantVariable(str: 'test')])",
            ):
                # This runs in fullgraph already
                with TestModeRaises():
                    flex_attention(
                        torch.ones(2, 2, 2, 2),
                        torch.ones(2, 2, 2, 2),
                        torch.ones(2, 2, 2, 2),
                    )

    @requires_gpu
    @skipIfXpu(msg="XPU does not support flex attention")
    def test_hop_eager(self):
        import torch
        import torch._higher_order_ops
        from torch.nn.attention.flex_attention import (
            flex_attention as flex_attention_eager,
        )

        with torch.device(GPU_TYPE):
            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported,
                "raised exception HopDetectionError([ConstantVariable(str: 'test')])",
            ):
                with TestModeRaises():
                    flex_attention_eager(
                        torch.ones(2, 2, 2, 2),
                        torch.ones(2, 2, 2, 2),
                        torch.ones(2, 2, 2, 2),
                    )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()