# Owner(s): ["module: dynamo"]
import operator
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._C import (
_len_torch_function_stack,
_pop_torch_function_stack,
_push_on_torch_function_stack,
)
from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
from torch.testing._internal.triton_utils import requires_cuda
from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode
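

# Torch function mode reused throughout these tests: it intercepts torch.add and
# returns a zeros tensor so tests can tell whether the mode was actually consulted.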
class TestMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.zeros(2, 2)
return super().__torch_function__(func, types, args, kwargs)
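

# Tests for Dynamo's behavior when a TorchDispatchMode is active.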
class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
@classmethod
def tearDownClass(cls):
super().tearDownClass()
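
    # Dynamo is expected to skip compiling frames while this TorchDispatchMode is
    # active: the compiled call must match eager and no frames should be compiled.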
def test_skip_torch_dispatch_modes(self):
class RewriteAddToMul(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if func is torch.ops.aten.add.Tensor:
func = torch.ops.aten.mul.Tensor
return func(*args, **kwargs)
def fn(x):
return x + x
cnt = torch._dynamo.testing.CompileCounter()
x = torch.tensor([3.0])
with RewriteAddToMul():
eager_res = fn(x)
compiled_res = torch.compile(fn, backend=cnt)(x)
self.assertEqual(eager_res, compiled_res)
self.assertEqual(cnt.frame_count, 0)
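

# Tests for how Dynamo traces, guards on, and mutates the torch function mode
# stack (BaseTorchFunctionMode subclasses and the DeviceContext installed by
# torch.set_default_device).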
class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
cls.default_device_old = torch.get_default_device()
super().setUpClass()
@classmethod
def tearDownClass(cls):
torch.set_default_device(cls.default_device_old)
super().tearDownClass()
def setUp(self):
torch.set_default_device(None)
torch._dynamo.reset()
def tearDown(self):
torch.set_default_device(None)
torch._dynamo.reset()
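
    # Shared helper: each distinct torch function mode stack should trigger a
    # recompile, while re-entering a previously seen stack should hit the cache.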
def _run_torch_function_mode_guard_test(self):
class TestMode1(BaseTorchFunctionMode):
pass
class TestMode2(BaseTorchFunctionMode):
pass
cnt = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnt.__call__)
def fn(x):
return x + 1
inp = torch.ones(2, 2)
fn(inp)
self.assertEqual(cnt.frame_count, 1)
with TestMode1():
fn(inp)
self.assertEqual(cnt.frame_count, 2)
with TestMode1(), TestMode2():
fn(inp)
self.assertEqual(cnt.frame_count, 3)
with TestMode2(), TestMode1():
fn(inp)
self.assertEqual(cnt.frame_count, 4)
with TestMode1():
fn(inp)
self.assertEqual(cnt.frame_count, 4)
@torch._dynamo.config.patch("enable_cpp_guard_manager", False)
def test_torch_function_mode_guards_py(self):
self._run_torch_function_mode_guard_test()
def test_torch_function_mode_guards_cpp(self):
self._run_torch_function_mode_guard_test()
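
    # Mutations to the mode stack made inside a compiled function (installing a
    # device context, popping a mode) must be visible on the real stack afterwards.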
def test_stack_state_mutation_default_device(self):
m = BaseTorchFunctionMode()
m1 = BaseTorchFunctionMode()
with m, m1:
@torch.compile(fullgraph=True)
def fn(x):
torch.set_default_device("cpu")
_pop_torch_function_stack()
fn(torch.ones(2, 2))
_push_on_torch_function_stack(m1)
stack = _get_current_function_mode_stack()
self.assertIsInstance(stack[0], DeviceContext)
self.assertEqual(stack[0].device, torch.device("cpu"))
self.assertIs(stack[1], m)
self.assertIs(stack[2], m1)
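
    # Setting or clearing the default device inside compiled code should leave the
    # stack holding exactly the expected DeviceContext (or none at all).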
def test_stack_state_clear_default_device(self):
@torch.compile(fullgraph=True)
def fn(x):
torch.set_default_device(None)
return x + 1
fn(torch.ones(2, 2))
stack = _get_current_function_mode_stack()
self.assertEqual(len(stack), 0)
m = BaseTorchFunctionMode()
m1 = BaseTorchFunctionMode()
# Stack populated, add device
with m, m1:
@torch.compile(fullgraph=True)
def fn(x):
torch.set_default_device("cpu")
torch.set_default_device(None)
torch.set_default_device("cpu")
return x + 1
fn(torch.ones(2, 2))
stack = _get_current_function_mode_stack()
self.assertEqual(stack[0].device, torch.device("cpu"))
self.assertIs(stack[1], m)
self.assertIs(stack[2], m1)
# Stack populated, remove device
torch.set_default_device("cpu")
with m, m1:
@torch.compile(fullgraph=True)
def fn(x):
torch.set_default_device(None)
return x + 1
fn(torch.ones(2, 2))
stack = _get_current_function_mode_stack()
self.assertIs(stack[0], m)
self.assertIs(stack[1], m1)
@torch.compile(fullgraph=True)
def fn(x):
torch.set_default_device("cpu")
torch.set_default_device("cpu")
return x + 1
fn(torch.ones(2, 2))
stack = _get_current_function_mode_stack()
self.assertEqual(stack[0].device, torch.device("cpu"))
torch.set_default_device(None)
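
    # Popping a mode inside compiled code removes it from the real stack on return.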
def test_pop_torch_function_mode(self):
m = BaseTorchFunctionMode()
with m:
@torch.compile(fullgraph=True)
def fn(x):
_pop_torch_function_stack()
return x + 1
fn(torch.ones(2, 2))
self.assertEqual(_len_torch_function_stack(), 0)
# reset stack so __exit__ doesn't crash
_push_on_torch_function_stack(m)
self.assertEqual(_len_torch_function_stack(), 0)
def test_error_empty_stack_pop_torch_function_mode(self):
@torch.compile(fullgraph=True)
def fn(x):
_pop_torch_function_stack()
return x + 1
self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"Popping from an empty torch function mode stack",
lambda: fn(torch.ones(2, 2)),
)
def test_push_torch_function_mode(self):
m = BaseTorchFunctionMode()
with m:
@torch.compile(fullgraph=True)
def fn(x, m):
_push_on_torch_function_stack(m)
return x + 1
fn(torch.ones(2, 2), m)
self.assertEqual(_len_torch_function_stack(), 2)
# reset stack state
_pop_torch_function_stack()
self.assertEqual(_len_torch_function_stack(), 0)
def test_len_torch_function_mode(self):
m = BaseTorchFunctionMode()
with m:
@torch.compile(fullgraph=True)
def fn(x):
z = _len_torch_function_stack()
return x + z
res = fn(torch.ones(2, 2))
self.assertEqual(res, torch.ones(2, 2) + 1)
self.assertEqual(_len_torch_function_stack(), 1)
    def test_intermediate_torch_function_mode_construction_mutation(self):
class TestMode(BaseTorchFunctionMode):
def __init__(self, x):
self.x = x
@torch.compile(fullgraph=True)
def fn(x):
z = TestMode(2)
z.y = 2
return x + 1, z
fn(torch.ones(2, 2))
def test_torch_function_mode_enabled_guard(self):
cnt = torch._dynamo.testing.CompileCounter()
inp = torch.ones(2, 2)
@torch.compile(backend=cnt.__call__)
def fn(x):
return x + 1
with BaseTorchFunctionMode(), torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
fn(inp)
fn(inp)
self.assertEqual(cnt.frame_count, 2)
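
    # With two modes stacked, both the results and which modes actually fired must
    # match between eager and compiled execution.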
def test_nested_torch_function_mode(self):
mode_1_called = False
mode_2_called = False
def reset_state():
nonlocal mode_1_called
nonlocal mode_2_called
mode_1_called = False
mode_2_called = False
ones = torch.ones(2, 2)
zeros = torch.zeros(2, 2)
class TestMode1(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
nonlocal mode_1_called
mode_1_called = True
if func == torch.add:
return zeros
return super().__torch_function__(func, types, args, kwargs)
class TestMode2(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
nonlocal mode_2_called
mode_2_called = True
if func == torch.mul:
return ones
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
def fn_2(x):
return torch.mul(x, 3) + torch.add(x, 3)
inp = torch.ones(2, 2) + 1
for fn_i in [fn, fn_2]:
fn_opt = torch.compile(fn_i, fullgraph=True)
with TestMode1(), TestMode2():
expected = fn_i(inp), mode_1_called, mode_2_called
reset_state()
actual = fn_opt(inp), mode_1_called, mode_2_called
reset_state()
self.assertEqual(expected, actual)
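
    # Calling compiled code while DisableTorchFunctionSubclass or
    # DisableTorchFunction is active must give the same result as eager.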
def test_torch_function_mode_disable(self):
class TestSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.ones(2, 2)
return super().__torch_function__(func, types, args, kwargs)
class TestMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.zeros(2, 2)
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
with torch._C.DisableTorchFunctionSubclass():
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
with torch._C.DisableTorchFunction():
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
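
    # An active torch function mode takes precedence over a tensor subclass's
    # __torch_function__; compiled and eager must agree on the result.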
def test_torch_function_mode_highest_priority(self):
class TestSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args, kwargs=None):
if not kwargs:
kwargs = {}
if func == torch.add:
return torch.ones(2, 2)
return super().__torch_function__(func, types, args, kwargs)
def fn(x):
return torch.add(x, 3)
inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)
fn_opt = torch.compile(fn, fullgraph=True)
with TestMode(), torch._dynamo.config.patch(
"traceable_tensor_subclasses", {TestSubclass}
):
expected = fn(inp)
actual = fn_opt(inp)
self.assertEqual(expected, actual)
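
    # Entering and exiting a mode entirely within the compiled region.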
def test_torch_function_mode_enter_exit(self):
def fn(x, y):
with TestMode():
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn, fullgraph=True)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_graph_break(self):
def fn(x, y):
with TestMode():
torch._dynamo.graph_break()
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
def test_torch_function_mode_and_pop_graph_break(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
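
    # If an exception escapes the compiled region, the mode pushed by the
    # `with TestMode()` block must not be left behind on the stack.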
def test_torch_function_mode_restore_on_exc(self):
@torch._dynamo.disable()
def err():
raise RuntimeError("test")
@torch.compile()
def fn(x):
with TestMode():
x += 1
err()
x += 2
return x
try:
fn(torch.ones(2, 2))
except RuntimeError:
pass
self.assertEqual(_len_torch_function_stack(), 0)
def test_torch_function_mode_and_pop_graph_break_mutation(self):
def fn(x, y):
with TestMode():
z = _pop_torch_function_stack()
z.y = 5
torch._dynamo.graph_break()
_push_on_torch_function_stack(z)
o = torch.add(x, 3)
o = torch.mul(o, z.y)
return torch.add(o, y)
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
fn_opt = torch.compile(fn)
expected = fn(*inp)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
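
    # Dynamo maps Python builtin operators (operator.add, operator.neg, ...) onto
    # equivalent torch/Tensor functions; FuncEquivMode below records the function
    # actually dispatched during compilation and checks it against
    # BUILTIN_TO_TENSOR_FN_MAP / BUILTIN_TO_TENSOR_RFN_MAP.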
# Needs larger cache size since we recompile for each op
@patch.object(torch._dynamo.config, "recompile_limit", 48)
def test_builtin_equivalent_funcs(self):
from torch._dynamo.variables.torch_function import (
bin_int_ops,
bin_ops,
BUILTIN_TO_TENSOR_FN_MAP,
BUILTIN_TO_TENSOR_RFN_MAP,
tensor_and_int_ops,
un_int_ops,
un_ops,
)
expected_func = None
valid = False
class FuncEquivMode(BaseTorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
nonlocal expected_func
nonlocal valid
if not kwargs:
kwargs = {}
if torch._dynamo.is_compiling():
valid = expected_func == func
return super().__torch_function__(func, types, args, kwargs)
inp0 = torch.ones(1, 1)
inp1 = torch.ones(1, 1)
inp0_int = torch.ones(1, 1, dtype=torch.int32)
inp1_int = torch.ones(1, 1, dtype=torch.int32)
@torch.compile(fullgraph=True)
def fn_un(op, inp):
return op(inp)
@torch.compile(fullgraph=True)
def fn_un_int(op, inp):
return op(inp)
@torch.compile(fullgraph=True)
def fn_bin(op, inp0, inp1):
return op(inp0, inp1)
@torch.compile(fullgraph=True)
def fn_bin_int(op, inp0, inp1):
return op(inp0, inp1)
@torch.compile(fullgraph=True)
def fn_tensor_and_int(op, inp0, inp1):
return op(inp0, inp1)
setups_and_oplists = [
(lambda o: fn_un(o, inp0), un_ops),
(lambda o: fn_un_int(o, inp0_int), un_int_ops),
(lambda o: fn_bin(o, inp0, inp1), bin_ops),
(lambda o: fn_bin_int(o, inp0_int, inp1_int), bin_int_ops),
(lambda o: fn_tensor_and_int(o, inp0_int, 0), tensor_and_int_ops),
]
# gather the reverse functions
rsetups_and_oplists = [
(
lambda o: fn_bin(o, 1, inp1),
bin_ops,
            ),  # Get r* ops (e.g. __sub__(int, Tensor) -> __rsub__(Tensor, int))
(lambda o: fn_bin_int(o, 1, inp1_int), bin_int_ops),
(lambda o: fn_tensor_and_int(o, 0, inp0_int), tensor_and_int_ops),
]
skips = {operator.not_} # Has local scalar dense call which graph breaks
rskips = {
operator.matmul,
operator.imatmul,
operator.getitem,
} # Doesn't type check with reversed args
def run_checks(setups_and_oplists, skips, ref_map):
nonlocal valid
nonlocal expected_func
for setup_fn, op_list in setups_and_oplists:
for op in op_list:
if op in skips or op not in ref_map:
continue
with FuncEquivMode():
expected_func = ref_map[op]
setup_fn(op)
self.assertTrue(valid)
expected_func = None
valid = False
run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP)
run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP)
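
    # Compiling a TransformedDistribution's rsample/log_prob fullgraph while a
    # default-device context ("cpu") is active; see the linked issue below.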
def test_expand(self):
from torch.distributions import (
AffineTransform,
ComposeTransform,
Normal,
TanhTransform,
TransformedDistribution,
)
# https://github.com/pytorch/pytorch/issues/141232
with torch.device("cpu"):
@torch.compile(fullgraph=True)
def func(a):
d = TransformedDistribution(
Normal(a, 1),
ComposeTransform([TanhTransform(), AffineTransform(2, 2)]),
)
b = d.log_prob(d.rsample((10,)))
return b
func(torch.randn(3))
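
    # With the default device set to "cuda" (which installs a DeviceContext torch
    # function mode), compiled flex_attention block-mask construction should still work.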
@requires_cuda
def test_flex_attention(self):
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
torch.set_default_device("cuda")
flex_attention = torch.compile(flex_attention, dynamic=False)
prefix_lengths = torch.arange(8)
def prefix_lm(b, h, q, kv):
return prefix_lengths[b] >= kv
# This runs in fullgraph already
create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
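
    # Registering a backward hook on an input tensor inside a compiled module,
    # under a device context manager, should compile fullgraph.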
def test_register_hook(self):
import functools
def my_hook(grad, *, k=0):
return grad + k
hook = functools.partial(my_hook, k=3)
class MyMod(torch.nn.Module):
def forward(self, x):
x.register_hook(hook)
y = x.mul(2)
z = y.mul(3)
return (z,)
mod = MyMod()
x = torch.ones(4, requires_grad=True)
with torch.device("cpu"):
torch.compile(mod, fullgraph=True)(x)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()