pytorch/test/dynamo/test_subclasses.py
Edward Z. Yang aec6332356 Only thunkify proxies in some situations (#132421)
The goal of this PR is to avoid stack overflow when we create extremely long chains of thunks and then evaluate them (e.g., as occurs if you sum() a long list of SymInts; a repro sketch follows the commit metadata below). The basic idea behind this PR is to only thunkify proxies if they are being created in places where they may or may not be used; crucially, SymInt operations that occur in user code we are tracing are eagerly placed into the graph, even if they may eventually be dead.

I annotated the PR with explanations of the changes.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132421
Approved by: https://github.com/Skylion007, https://github.com/zou3519
ghstack dependencies: #132674, #132675
2024-08-08 12:03:06 +00:00
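
The failure mode described above can be reproduced (in spirit) with user code that folds a long list of dynamic sizes into a single SymInt. The sketch below is illustrative only and is not taken from the PR; the list length needed to trigger the overflow and the choice of the eager backend are assumptions.

import torch

@torch.compile(backend="eager", fullgraph=True, dynamic=True)
def sum_sizes(tensors):
    # Under dynamic shapes, each tensors[i].shape[0] is a SymInt, so the
    # running sum builds a long chain of SymInt additions. Before this PR,
    # each intermediate could be wrapped in a thunk, and forcing the final
    # value recursed through the whole chain, overflowing the stack.
    total = 0
    for t in tensors:
        total = total + t.shape[0]
    return torch.zeros(total)

sum_sizes([torch.randn(i + 2) for i in range(2000)])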


# Owner(s): ["module: dynamo"]
import functools
import itertools
import unittest
from functools import partial
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.testing import normalize_gm
from torch._higher_order_ops.wrap import wrap
from torch.fx.experimental.symbolic_shapes import (
DimDynamic,
ShapeEnv,
StatelessSymbolicContext,
)
from torch.nested._internal.nested_tensor import (
jagged_from_list,
jagged_from_tensor_and_lengths,
nested_view_from_values_offsets,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
NestedTensorTestCase,
parametrize,
subtest,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
def traceable_subclass(c):
return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})
def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2)
self.assertEqual(actual_recompiles, expected_recompiles)
def get_jagged_tensor(nested_size, offsets, requires_grad=True):
# Makes a jagged tensor with N constituent tensors with size
# as specified ((S0, S1, S2), D)
D = nested_size[1]
out = []
for s in nested_size[0]:
out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64))
return jagged_from_list(out, offsets)
def get_view_test_cases():
# Test all cases with both an NT base and a dense base
# Subclass -> Subclass
# Dense -> Subclass
# NB: Don't close over loop variables, they will not get copied into the
# closure
#
# NB: These return functions so we don't generate tensors during test
# collection time
def mk_basic(base_is_nt):
# There are three cases to consider here based on the logic in
# meta_utils.py
#
# (1) basic case:
# view is not a leaf and has the same requires_grad as its base
x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
x = x.clone() if base_is_nt else x
assert not x.is_leaf
return x.unsqueeze(-1)
def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2):
x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1)
x = x.clone() if base_is_nt else x
with torch.no_grad():
x_view = x.unsqueeze(-1)
# The issue is this doesn't quite work
x_view.requires_grad_(requires_grad_2)
return x_view
def mk_obscure(base_is_nt):
x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
x = x.clone() if base_is_nt else x
# intermediate leaf view
with torch.no_grad():
x_view = x.unsqueeze(-1)
x_view.requires_grad_(True)
x_view_view = x_view.unsqueeze(-1)
return x_view_view
for base_is_nt in [False, True]:
prefix = f"base_is_nt_{base_is_nt}"
yield partial(mk_basic, base_is_nt), f"{prefix}_basic"
# (2) leaf view case:
# the view has to be a leaf (w/ requires_grad True or requires_grad False)
# base w/ requires_grad True or requires_grad False
for requires_grad_1, requires_grad_2 in itertools.product(
[True, False], repeat=2
):
yield partial(
mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"
# (3) obscure case:
# view is not a leaf (implies requires_grad True)
# base w/ requires_grad False
yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
# Subclass -> Dense
yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
0
].clone(), "subclass_dense"
# Dense -> Subclass -> Dense -> Subclass
def mk_dense_subclass_dense_subclass():
values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])
offsets2 = offsets.clone().detach()
return nested_view_from_values_offsets(
nested_view_from_values_offsets(values, offsets).values(), offsets
)
yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass"
def mk_subclass_dense_subclass_dense():
x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
offsets2 = x.offsets().clone().detach()
nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
return nt_view
yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense"
VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()}
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
compile_full_eager = torch.compile(backend="eager", fullgraph=True)
class BaseTorchFunction(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
class MockSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)
class AttrSubclass(torch.Tensor):
x: int = 10
size: int = 10
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)
class DummyNDim(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func == torch.Tensor.ndim.__get__:
return 10
return super().__torch_function__(func, types, args, kwargs)
class WrapperSubclass:
def __init__(self, tensor):
self.tensor = tensor
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args)
kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs)
return func(*args, **kwargs)
class SigmoidToExpSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func == torch.Tensor.sigmoid:
return super().__torch_function__(torch.Tensor.exp, types, args, kwargs)
return super().__torch_function__(func, types, args, kwargs)
# Wrapper subclass with two inner tensors: data and scale
# data has same shape as outer, and scale has single dim size
class ScaledTensor(torch.Tensor):
def __new__(
cls,
data: torch.Tensor,
scale: torch.Tensor,
*,
constant: int = 0,
):
return torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=data.dtype,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0):
self._data = data
self._scale = scale
self._constant = constant
def __tensor_flatten__(self):
ctx = {"_constant": self._constant}
return ["_data", "_scale"], ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
assert len(inner_tensors) == 2
return ScaledTensor(
inner_tensors["_data"],
inner_tensors["_scale"],
constant=metadata["_constant"],
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
scaled_tensor = args[0]
out = func(scaled_tensor._data, *args[1:], **kwargs)
return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant)
def __repr__(self):
return f"{self._data.__repr__()}\n{self._scale.__repr__()}"
class OptionalScaledTensor(torch.Tensor):
def __new__(
cls,
data,
scale,
*,
constant: int = 0,
):
return torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
strides=data.stride(),
storage_offset=data.storage_offset(),
dtype=data.dtype,
layout=data.layout,
requires_grad=data.requires_grad,
device=data.device,
)
def __init__(self, data: torch.Tensor, scale, constant: int = 0):
self._data = data
self._scale = scale
self._constant = constant
def __tensor_flatten__(self):
ctx = {"_constant": self._constant}
if self._scale is not None:
return ["_data", "_scale"], ctx
else:
return ["_data"], ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
return OptionalScaledTensor(
inner_tensors["_data"],
inner_tensors["_scale"] if "_scale" in inner_tensors else None,
constant=metadata["_constant"],
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
scaled_tensor = args[0]
out = func(scaled_tensor._data, *args[1:], **kwargs)
if scaled_tensor._scale is not None:
out = out * scaled_tensor._scale
return OptionalScaledTensor(
out, scaled_tensor._scale, constant=scaled_tensor._constant
)
def __repr__(self):
return (
f"OptionalScaledTensor({self._data.__repr__()}\n{self._scale.__repr__()})"
)
class CtxSubclassTensor(torch.Tensor):
"""
Class used to verify guarding on the subclass metadata
"""
@staticmethod
def __new__(cls, a, constant):
shape = a.shape
kwargs = {}
kwargs["strides"] = a.stride()
kwargs["storage_offset"] = a.storage_offset()
kwargs["device"] = a.device
kwargs["layout"] = a.layout
kwargs["requires_grad"] = a.requires_grad
kwargs["dtype"] = a.dtype
out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
return out
def __init__(self, a, constant):
self.a = a
self.constant = constant
def __repr__(self):
a_repr = repr(self.a)
return f"CtxSubclassTensor({a_repr})"
def __tensor_flatten__(self):
return ["a"], (self.constant,)
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, sizes, strides):
constant = meta[0]
a = inner_tensors["a"]
return CtxSubclassTensor(a, constant)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
from torch.utils._python_dispatch import return_and_correct_aliasing
if kwargs is None:
kwargs = {}
biggest_constant = max(
[
x.constant
for x in pytree.tree_flatten(args)[0]
if isinstance(x, CtxSubclassTensor)
]
)
args_a = pytree.tree_map(
lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, args
)
kwargs_a = pytree.tree_map(
lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, kwargs
)
out_a = func(*args_a, **kwargs_a)
out = pytree.tree_map(
lambda x: CtxSubclassTensor(x, biggest_constant)
if isinstance(x, torch.Tensor)
else x,
out_a,
)
if func == torch.ops.aten.mul.Tensor:
out = out + out.constant
return return_and_correct_aliasing(func, args, kwargs, out)
def func(a):
return a.sin()
class EagerRecordGraphAndInputs:
def __init__(self) -> None:
self.graphs = []
self.example_inputs = []
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
self.graphs.append(gm)
self.example_inputs.append(example_inputs)
return gm
GLOBAL_TEST_SUBCLASSES = {
MockSubclass,
DummyNDim,
SigmoidToExpSubclass,
BaseTorchFunction,
}
# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
compile_count = [0]
def counter(gm, example_inputs):
compile_count[0] += 1
return gm
compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
compiled_f(*inputs1)
compiled_f(*inputs2)
return compile_count[0] > 1
class SubclassTests(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack.enter_context(
torch._dynamo.config.patch(
"traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES
)
)
@classmethod
def tearDownClass(cls):
cls._exit_stack.close()
def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
_check_recompiles(self, fn, inputs1, inputs2, expected_recompiles)
def test_no_call_to_new(self):
class BadNewTorchFunction(torch.Tensor):
def __new__(cls, *args, **kwargs):
raise RuntimeError("Oops!")
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
with torch._dynamo.config.patch(
"traceable_tensor_subclasses", {BadNewTorchFunction}
):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return torch.add(x, 1)
input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)
res = fn(input)
self.assertIsInstance(res, BadNewTorchFunction)
def test_no_torch_function_recompiles(self):
class NJT:
def __repr__(self):
return f"NJT(shape={self.shape})"
def __init__(self, values, offsets):
self._values = values
self._offsets = offsets
def sin(self):
return torch.sin(self)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func == torch.sin:
self = args[0]
return NJT(func(self._values), self._offsets)
raise AssertionError("should not get here")
values1 = torch.randn(10, 3, 4, requires_grad=True)
values2 = torch.randn(10, 3, 4, requires_grad=True)
offsets = torch.tensor([0, 3, 10])
njt1 = NJT(values1, offsets)
njt2 = NJT(values2, offsets)
@torch.compile(backend="eager", fullgraph=True)
def f(x):
return torch.sin(x)
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
f(njt1)
f(njt2)
def test_base_torch_function_tracing(self):
def fn(x):
return torch.add(x, 1)
input = torch.ones(2, 2).as_subclass(BaseTorchFunction)
out = fn(input)
out_opt = compile_full_eager(fn)(input)
self.assertIsInstance(out, BaseTorchFunction)
self.assertEqual(out, out_opt)
def test_torch_function_state_graph_break(self):
@torch.compile(backend="eager")
def fn(x):
with torch._C.DisableTorchFunctionSubclass():
torch._dynamo.graph_break()
return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)
input = torch.ones(2, 2)
res, _ = fn(input)
self.assertFalse(res)
def test_torch_function_state_nested(self):
@torch.compile(backend="eager")
def fn(x):
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunctionSubclass():
x = x + 1
# Should reset to the outer state (disabled) after exiting ctx manager
return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)
input = torch.ones(2, 2)
res, _ = fn(input)
self.assertFalse(res)
def test_torch_function_state_tracing(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
with torch._C.DisableTorchFunctionSubclass():
torch.add(x, 1.0)
input = torch.ones(2, 2)
res = fn(input)
def test_torch_function_state_guards(self):
cnt = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnt, fullgraph=True)
def fn(x):
torch.add(x, 1.0)
input = torch.ones(2, 2)
with torch._C.DisableTorchFunctionSubclass():
res = fn(input)
res = fn(input)
self.assertEqual(cnt.frame_count, 2)
def test_return_subclass(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return MockSubclass(torch.add(x, 1.0))
input = torch.ones(2, 2)
res = fn(input)
self.assertIsInstance(res, MockSubclass)
def test_return_as_subclass(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return torch.add(x, 1.0).as_subclass(MockSubclass)
input = torch.ones(2, 2)
res = fn(input)
self.assertIsInstance(res, MockSubclass)
def test_return_local_subclass(self):
class LocalSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)
with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return LocalSubclass(torch.add(x, 1.0))
input = torch.ones(2, 2)
res = fn(input)
self.assertIsInstance(res, LocalSubclass)
def test_torch_function_list_args(self):
HANDLED_FUNCTIONS = {}
class MyClass:
def __init__(self, foo):
self.foo = foo
@classmethod
def __torch_function__(
cls,
func,
types,
args=(),
kwargs=None,
):
if kwargs is None:
kwargs = {}
if func not in HANDLED_FUNCTIONS or not all( # noqa: C419
[ # noqa: C419
issubclass(t, (torch.Tensor, MyClass)) for t in types
]
):
return NotImplemented
return HANDLED_FUNCTIONS[func](*args, **kwargs)
def _stack(input, dim=0, *, out=None):
return MyClass(sum([x.foo for x in input]))
HANDLED_FUNCTIONS[torch.stack] = _stack
@torch.compile(backend="eager", fullgraph=True)
def fn(v0, v1):
return torch.stack([v0, v1])
ret = fn(MyClass(1), MyClass(1))
self.assertEqual(ret.foo, 2)
@parametrize(
"comparison",
[
subtest(isinstance, "isinstance"),
subtest(lambda instance, type_: type(instance) == type_, "equality"),
subtest(lambda instance, type_: type(instance) is type_, "identity"),
],
)
@parametrize(
"input_type",
[
subtest(torch.Tensor, "tensor"),
subtest(DummyNDim, "subclass"),
],
)
def test_type_check(self, comparison, input_type):
with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}):
def fn(x):
if comparison(x, DummyNDim):
return torch.ones(1, 1)
else:
return torch.zeros(2, 2)
input = torch.ones(2, 2).as_subclass(input_type)
exp_res = fn(input)
act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
self.assertEqual(exp_res, act_res)
def test_torch_function_call_on_method(self):
x = torch.ones(2, 2)
y = torch.ones(2, 2)
z = torch.ones(2, 2)
wrapped = x.as_subclass(SigmoidToExpSubclass)
wrapped2 = y.as_subclass(SigmoidToExpSubclass)
def fn(w):
return w.sigmoid()
fn_opt = compile_full_eager(fn)
res_exp = fn(wrapped)
res_act = fn_opt(wrapped2)
res_exp2 = z.exp()
self.assertEqual(res_exp, res_act)
self.assertEqual(res_exp, res_exp2)
def test_user_overidden_method_unsupported(self):
class LocalSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
def sigmoid(self):
return None
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
x.sigmoid()
msg = (
"Accessing overridden method/attribute sigmoid on a tensor"
" subclass with a __torch_function__ override is not supported"
)
with torch._dynamo.config.patch(
"traceable_tensor_subclasses", {LocalSubclass}
), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn(x)
def test_user_overidden_attr_unsupported(self):
class LocalSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
ndim = 10
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return x.ndim
msg = (
"Accessing overridden method/attribute ndim on a tensor"
" subclass with a __torch_function__ override is not supported"
)
with torch._dynamo.config.patch(
"traceable_tensor_subclasses", {LocalSubclass}
), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn(x)
def test_user_overidden_property_unsupported(self):
class LocalSubclass(torch.Tensor):
def __init__(self) -> None:
self._ndim = 10
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
@property
def ndim(self):
return self._ndim
@ndim.setter
def ndim(self, value):
self._ndim = value
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return x.ndim
msg = (
"Accessing overridden method/attribute ndim on a tensor"
" subclass with a __torch_function__ override is not supported"
)
with torch._dynamo.config.patch(
"traceable_tensor_subclasses", {LocalSubclass}
), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn(x)
def test_overridden_method_guarding(self):
class LocalSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
@torch.compile(backend="eager")
def fn(x):
return x.sigmoid()
with torch._dynamo.config.patch(
error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass}
):
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn(x)
fn(x)
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn(x)
with torch._dynamo.config.patch(
traceable_tensor_subclasses={LocalSubclass}
), self.assertRaisesRegex(
TypeError,
"'bool' object is not callable",
):
LocalSubclass.sigmoid = False
fn(x)
def test_torch_function_call_on_attr(self):
x = torch.ones(2, 2)
wrapped = x.as_subclass(DummyNDim)
def fn(w):
return w.ndim + torch.ones(2)
fn_opt = compile_full_eager(fn)
res_exp = fn(wrapped)
res_act = fn_opt(wrapped)
self.assertEqual(res_exp, res_act)
self.assertEqual(res_exp, torch.ones(2) + 10)
def test_torch_function_wrapper_class(self):
x = torch.ones(2, 2)
wrapped = WrapperSubclass(x)
def fn(w):
return torch.add(w, 1.0)
fn_opt = compile_full_eager(fn)
res_exp = fn(wrapped)
res_act = fn_opt(wrapped)
self.assertEqual(res_exp, res_act)
def test_torch_function_wrapper_class_with_kwargs(self):
x = torch.ones(2, 2)
wrapped = WrapperSubclass(x)
def fn(w):
return torch.add(w, 1.0, alpha=2.0)
fn_opt = compile_full_eager(fn)
res_exp = fn(wrapped)
res_act = fn_opt(wrapped)
self.assertEqual(res_exp, res_act)
def test_tensor_subclass_custom_attr(self):
class AttrSubclass(torch.Tensor):
x: int = 10
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
return x.x + torch.ones(2, 2)
with traceable_subclass(AttrSubclass):
input = torch.ones(2, 2).as_subclass(AttrSubclass)
fn_opt = compile_full_eager(fn)
res_exp = fn(input)
res_act = fn_opt(input)
self.assertEqual(res_exp, res_act)
def test_compile_with_fake_tensor_dynamic_dim(self):
x = torch.randn([3, 4])
def f(x):
return torch.sin(x)
def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count):
torch._dynamo.reset()
cnt = torch._dynamo.testing.CompileCounter()
opt_f = torch.compile(f, backend=cnt, fullgraph=True)
x1 = torch.rand_like(x)
f(x)
f(torch.randn([4, 3]))
shape_env = ShapeEnv()
with torch._subclasses.fake_tensor.FakeTensorMode(
shape_env=shape_env
) as fake_mode:
x_fake = fake_mode.from_tensor(
x,
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=[dim_dynamic for i in range(x.dim())]
),
)
x1_fake = fake_mode.from_tensor(
x1,
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=[dim_dynamic for i in range(x.dim())]
),
)
opt_f(x_fake)
opt_f(x1_fake)
self.assertEqual(cnt.frame_count, exp_frame_count)
self.assertEqual(cnt.op_count, exp_op_count)
test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1)
test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1)
test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1)
def test_compile_with_fake_tensor_automatic_dynamic(self):
def f(x):
return torch.sin(x)
def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count):
torch._dynamo.reset()
cnt = torch._dynamo.testing.CompileCounter()
opt_f = torch.compile(f, backend=cnt, fullgraph=True)
shape_env = ShapeEnv()
with torch._subclasses.fake_tensor.FakeTensorMode(
shape_env=shape_env
) as fake_mode:
for inp in inps:
fake_inp = fake_mode.from_tensor(
inp,
symbolic_context=StatelessSymbolicContext(
[dim_dynamic for i in range(x.dim())]
),
)
opt_f(fake_inp)
self.assertEqual(cnt.frame_count, exp_frame_count)
self.assertEqual(cnt.op_count, exp_op_count)
x = torch.randn([3, 4])
y = torch.randn([4, 5])
z = torch.randn([5, 6])
a = torch.randn([3, 5])
b = torch.randn([4, 4])
# When inputs' DimDynamic is DYNAMIC or DUCK, the inputs
# to opt_f will be tensors with SymInt sizes. Dynamo will treat input
# as dynamic automatically and will only compile once
for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]:
test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1)
test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1)
test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1)
for dim_dynamic in [DimDynamic.STATIC]:
# Recompile once, first with dim 0 and 1 become Dynamic
test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2)
# Recompile 2 times, first with dim 1 become Dynamic, second with dim 0 becomes Dynamic.
test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3)
# Recompile 2 times, first with dim 0 become Dynamic, second with dim 1 becomes Dynamic.
test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3)
def test_compile_with_functionalization(self):
x = torch.randn([3, 4])
x_clone = x.clone()
x_clone2 = x.clone()
backend = EagerRecordGraphAndInputs()
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
@torch.compile(backend=cnt, fullgraph=True)
def f(x):
return x.add_(1.0) + torch.nn.functional.relu_(x)
f_out = f(x)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 3)
self.assertEqual(len(backend.graphs), 1)
self.assertEqual(len(backend.example_inputs), 1)
actual = normalize_gm(backend.graphs[0].print_readable(print_output=False))
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
add_: "f32[3, 4]" = l_x_.add_(1.0)
relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None
add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None
return (add,)
""",
)
ff = torch.func.functionalize(f)
ff_out = ff(x_clone)
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 6)
self.assertEqual(len(backend.graphs), 2)
self.assertEqual(len(backend.example_inputs), 2)
actual = normalize_gm(backend.graphs[1].print_readable(print_output=False))
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
add_: "f32[3, 4]" = l_x_.add_(1.0)
relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None
add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None
return (add,)
""",
)
self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
# Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
def to_fun(x):
x_functional = torch._to_functional_tensor(x)
torch._mirror_autograd_meta_to(x, x_functional)
return x_functional
def aot_f_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
torch._enable_functionalization(reapply_views=False)
try:
func_args = pytree.tree_map(to_fun, args)
func_kwargs = pytree.tree_map(to_fun, kwargs)
return func(*func_args, **func_kwargs)
finally:
torch._disable_functionalization()
return wrapper
aot_ff = aot_f_wrapper(f)
aot_ff_out = aot_ff(x_clone2)
self.assertEqual(cnt.frame_count, 3)
self.assertEqual(cnt.op_count, 9)
self.assertEqual(len(backend.graphs), 3)
self.assertEqual(len(backend.example_inputs), 3)
actual = normalize_gm(backend.graphs[2].print_readable(print_output=False))
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
add_: "f32[3, 4]" = l_x_.add_(1.0)
relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None
add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None
return (add,)
""",
)
self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
self.assertEqual(f_out, ff_out)
self.assertEqual(f_out, aot_ff_out)
try:
torch._enable_functionalization(reapply_views=False)
xf = pytree.tree_map(to_fun, x)
x_view = xf.t()
with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"):
f(x_view)
finally:
torch._disable_functionalization()
def test_compile_higher_order_with_functionalization(self):
backend = EagerRecordGraphAndInputs()
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
@torch.compile(backend=cnt, fullgraph=True)
def f(x):
return wrap(lambda x: x.add_(1.0), x)
def check_count_and_graph(
exp_frame_count, exp_op_count, exp_n_graph, exp_graph
):
self.assertEqual(cnt.frame_count, exp_frame_count)
self.assertEqual(cnt.op_count, exp_op_count)
self.assertEqual(len(backend.graphs), exp_n_graph)
actual = normalize_gm(
backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
)
self.assertExpectedInline(actual, exp_graph, skip=1)
t = torch.randn([3, 4])
t_clone = t.clone()
t_clone2 = t.clone()
f(t)
check_count_and_graph(
1,
2,
1,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
getitem: "f32[3, 4]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3, 4]"):
add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
return (add_,)
""",
)
ff = torch.func.functionalize(f)
ff_out = ff(t_clone)
# frame count and op count are incremented due to re-compilation
check_count_and_graph(
2,
4,
2,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
getitem: "f32[3, 4]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3, 4]"):
add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
return (add_,)
""",
)
try:
x = torch._to_functional_tensor(t_clone2)
torch._mirror_autograd_meta_to(t_clone2, x)
torch._enable_functionalization(reapply_views=False)
aot_f_out = f(x)
finally:
torch._disable_functionalization()
# frame count and op count are incremented due to re-compilation
check_count_and_graph(
3,
6,
3,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
getitem: "f32[3, 4]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3, 4]"):
add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
return (add_,)
""",
)
def test_has_torch_function(self):
class MyTensor:
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.max:
return torch.tensor(123)
return func(*args, **kwargs)
class LocalSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)
def fn(x):
return torch.overrides.has_torch_function_unary(
x
), torch.overrides.has_torch_function_variadic(x)
for test_class in [MyTensor, LocalSubclass]:
x = test_class()
ref0 = fn(x)
ref1 = fn(4)
opt_fn = torch._dynamo.optimize("eager")(fn)
res0 = opt_fn(x)
res1 = opt_fn(4)
self.assertEqual(ref0, res0)
self.assertEqual(ref1, res1)
def test_wrapper_subclass_guards_on_inner_tensor(self):
# Holds an inner tensor that has a distinct shape from the outer wrapper tensor.
# Also adds additional guards on the inner tensor's sizes.
# When the first input to an op has an inner tensor with shape[0] > 3, we insert an extra add node.
class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor):
@staticmethod
def __new__(cls, inner):
# Double the outer-most dimension
outer_shape = (inner.shape[0] * 2,) + inner.shape[1:]
return torch.Tensor._make_wrapper_subclass(
# TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
# Calling the overload that has kwargs causes us to go down the first overload path,
# which will **always** specialize sizes.
# We should probably eventually fix this so that the first overload can just handle dynamic shapes.
cls,
outer_shape,
inner.stride(),
None,
None,
inner.dtype,
inner.layout,
inner.device,
False,
inner.requires_grad,
)
def __init__(self, inner):
self.inner_elem = inner
def __tensor_flatten__(self):
return ["inner_elem"], None
@staticmethod
def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride):
return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])
def __repr__(self):
return f"DoubleSizeMayberAddGeThreeTensor({repr(self.inner_elem)})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
args_inner = torch.utils._pytree.tree_map_only(
DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args
)
out_inner = func(*args_inner, **kwargs)
# Add guards on the inner tensor's sizes
if args_inner[0].shape[0] > 3:
out_inner += 2
return DoubleSizeMaybeAddGeThreeTensor(out_inner)
curr_var_to_val = None
curr_var_to_sources = None
guards = None
def backend(gm, args):
context = torch._guards.TracingContext.get()
# Grab info on sources and guards from the shapeenv
nonlocal curr_var_to_val
nonlocal curr_var_to_sources
nonlocal guards
guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
curr_var_to_val = {
str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
}
curr_var_to_sources = {
str(k): v[0].name()
for k, v in context.fake_mode.shape_env.var_to_sources.items()
}
return gm
@torch.compile(backend=backend)
def fn(x):
if x.shape[0] < 13:
return torch.mul(x, x)
else:
return torch.div(x, x)
inp = torch.ones(4, 4)
x = DoubleSizeMaybeAddGeThreeTensor(inp)
torch._dynamo.mark_dynamic(x, 0)
res = fn(x)
# During fakeifying, we end up allocating a separate symint
# for the outer and inner tensor (in this test, s0 is unused).
expected_var_to_val = {
"s0": 8,
"s1": 4,
}
expected_var_to_sources = {
"s0": "L['x'].size()[0]",
"s1": "L['x'].inner_elem.size()[0]",
}
self.assertEqual(curr_var_to_val, expected_var_to_val)
self.assertEqual(curr_var_to_sources, expected_var_to_sources)
self.assertExpectedInline(
"\n".join(guards),
"""\
Eq(2*s1, s0)
2*s1 < 13
s1 > 3""",
)
def test_wrapper_subclass_with_same_sized_inner_tensor(self):
# shouldn't recompile for different sizes when dynamic=True
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7))
self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True))
# should recompile for different data size when dynamic=False
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
# avoid recompile using manual mark_dynamic() for different data size
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
# NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size
torch._dynamo.mark_dynamic(sub1, 0)
torch._dynamo.mark_dynamic(sub1, 1)
sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
def test_wrapper_subclass_with_differently_sized_inner_tensor(self):
# should recompile for different scale size when dynamic=False
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
# still recompiles using manual mark_dynamic() on outer for different scale size
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
# NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size
torch._dynamo.mark_dynamic(sub1, 0)
torch._dynamo.mark_dynamic(sub1, 1)
sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
def test_recompiles_with_optional_inner_tensor(self):
def f(x):
return x + 1
# sub1 does not have the optional tensor specified while sub2 does
sub1 = OptionalScaledTensor(torch.randn(2, 4), None)
sub2 = OptionalScaledTensor(torch.randn(2, 4), torch.randn(2, 4))
# sanity check; don't recompile for same input
self.assertFalse(_recompiles_for_inputs(f, (sub1,), (sub1,), dynamic=True))
self.assertFalse(_recompiles_for_inputs(f, (sub2,), (sub2,), dynamic=True))
# these should recompile; optional tensor changes between specified and unspecified
self.assertTrue(_recompiles_for_inputs(f, (sub1,), (sub2,), dynamic=True))
self.assertTrue(_recompiles_for_inputs(f, (sub2,), (sub1,), dynamic=True))
f_compiled = torch.compile(f, backend="aot_eager")
self.assertEqual(f(sub1)._data, f_compiled(sub1)._data)
self.assertEqual(f(sub2)._data, f_compiled(sub2)._data)
def test_torch_dispatch_subclass_guard_recompile(self):
x = torch.ones(2, 2)
x_two = TwoTensor(x.clone(), x.clone())
def fn(w):
return torch.add(w, 1.0)
fn_opt = torch.compile(backend="eager")(fn)
ref = fn(x_two)
res = fn_opt(x_two)
self.assertEqual(ref, res)
# ensure no recompilation on same input type
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
fn_opt(TwoTensor(x + 1, x + 2))
# recompile!
ref = fn(x)
res = fn_opt(x)
self.assertEqual(ref, res)
def test_tensor_subclass_ctx_guards(self):
x = CtxSubclassTensor(torch.ones(2), 3)
x2 = CtxSubclassTensor(torch.ones(2), 3)
x3 = CtxSubclassTensor(torch.ones(2), 4)
_check_recompiles(self, lambda x: x * x, (x,), (x2,), False)
_check_recompiles(self, lambda x: x * x, (x,), (x3,), True)
def test_tensor_subclass_ctx_recursive_guards(self):
x0 = torch.ones(2, 2)
x1 = CtxSubclassTensor(x0.clone(), 2)
x2 = CtxSubclassTensor(x0.clone(), 3)
tt0 = TwoTensor(x0.clone(), x1)
tt1 = TwoTensor(x0.clone(), x2)
_check_recompiles(self, lambda x: x * x, (tt0,), (tt1,), True)
def test_tensor_subclass_ctx_custom_guards_override(self):
class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
@classmethod
def __metadata_guard__(cls, orig_data, other):
return orig_data[0] <= other[0]
x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 2)
x2 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
x3 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 1)
_check_recompiles(self, lambda x: x * x, (x,), (x2,), False)
_check_recompiles(self, lambda x: x * x, (x,), (x3,), True)
def test_tensor_subclass_ctx_custom_guards_error_arg_num(self):
import torch._dynamo.exc
class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
@classmethod
def __metadata_guard__(cls, y):
# Shouldn't reach here
return False
x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
self.assertRaisesRegex(
torch._dynamo.exc.InternalTorchDynamoError,
"Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments",
lambda: torch.compile(lambda x: x * x)(x),
)
def test_tensor_subclass_ctx_custom_guards_error_not_classmethod(self):
import torch._dynamo.exc
class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
def __metadata_guard__(self, x, y):
return False
x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
self.assertRaisesRegex(
torch._dynamo.exc.InternalTorchDynamoError,
"Tensor subclass method __metadata_guard__ must be a classmethod",
lambda: torch.compile(lambda x: x * x)(x),
)
def test_torch_function_subclass_survives_into_aot_autograd(self):
# If you have a tensor subclass that relies on dispatch into the same op
# without unwrapping and calling torch._C.DisableTorchFunctionSubclass(),
# the torch function-ness will survive into AOTAutograd. Today, NestedTensor
# actually relies on this behavior! Because that torch function logic
# runs during AOTAutograd, this test checks that there is no logic below
# that relies on torch function dispatch and gets unexpectedly disabled after we
# redispatch from the subclass's torch function.
class SubTensor(torch.Tensor):
@staticmethod
def __new__(cls, t):
return torch.Tensor._make_wrapper_subclass(
cls,
t.shape,
t.stride(),
t.storage_offset(),
torch.contiguous_format,
t.dtype,
torch.strided,
t.device,
False,
t.requires_grad,
"sizes",
False,
False,
None,
)
def __init__(self, t):
super().__init__()
self._t = t
def __tensor_flatten__(self):
return ["_t"], {}
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
t = inner_tensors["_t"]
return SubTensor(t)
def __repr__(self):
return f"SubTensor({self._t})"
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args)
output = func(*new_args, **kwargs)
output = pytree.tree_map_only(
torch.Tensor, lambda t: SubTensor(t), output
)
return output
@torch.compile(dynamic=True)
def f(x):
return x.unflatten(-1, [2, 5])
s = SubTensor(torch.randn(3, 10))
f(s)
# Guard validation upsets the guard
# https://github.com/pytorch/pytorch/issues/129936
@unittest.expectedFailure
def test_recompile_with_symbool_inputs(self):
def f(pred: bool):
if pred:
return torch.ones([3, 4])
else:
return torch.ones([4, 3])
def test_recompilation(
f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
):
torch._dynamo.reset()
shape_env = ShapeEnv()
backend = torch._dynamo.testing.EagerAndRecordGraphs()
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
f_cond = torch.compile(f, backend=cnt, fullgraph=True)
with torch._subclasses.fake_tensor.FakeTensorMode(
shape_env=shape_env
) as fake_mode:
fake_inp = fake_mode.from_tensor(
x,
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())]
),
)
for i, size in enumerate(sizes):
pred = fake_inp.size(0) == size
f_cond(pred)
actual = normalize_gm(
backend.graphs[exp_frame_count[i] - 1].print_readable(
print_output=False
)
)
actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
self.assertExpectedInline(actual, exp_graphs[i])
self.assertEqual(cnt.frame_count, exp_frame_count[i])
self.assertEqual(actual_guard_str, exp_shape_env_guards[i])
true_graph = """\
class GraphModule(torch.nn.Module):
def forward(self):
ones: "f32[3, 4]" = torch.ones([3, 4])
return (ones,)
"""
false_graph = """\
class GraphModule(torch.nn.Module):
def forward(self):
ones: "f32[4, 3]" = torch.ones([4, 3])
return (ones,)
"""
test_recompilation(
f,
torch.randn([3, 4]),
[3, 3, 4, 5],
exp_graphs=[true_graph, true_graph, false_graph, false_graph],
exp_frame_count=[1, 1, 2, 2],
exp_shape_env_guards=[
[],
# s0 is specialized and guarded in outer shape_env when dynamo checks the guards
["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
[
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
"Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
],
[
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
"Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
],
],
)
test_recompilation(
f,
torch.randn([3, 4]),
[4, 5, 3, 3],
exp_graphs=[false_graph, false_graph, true_graph, true_graph],
exp_frame_count=[1, 1, 2, 2],
exp_shape_env_guards=[
[],
# s0 is specialized and guarded in outer shape_env when dynamo checks the guards
["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
[
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
],
[
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
],
],
)
def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self):
def f(x_subclass):
tmp_subclass = torch.add(x, 1)
return torch.mul(tmp_subclass._scale, tmp_subclass._constant)
x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2)
out_ref = f(x)
out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
self.assertEqual(out_ref, out_test)
def test_support_bases(self):
import abc
import torch.fx._symbolic_trace
class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
def __new__(cls, name, bases, dct):
x = super().__new__(cls, name, bases, dct)
x.attr = 100
return x
class Multistreamable(abc.ABC): # noqa: B024
pass
class Foo(Multistreamable, metaclass=Meta):
pass
@torch.compile(backend="eager", fullgraph=True)
def f(x):
typ = type(Foo())
typ.__bases__
return typ.__bases__
self.assertEqual(f(torch.randn(1)), (Multistreamable,))
@parametrize("dynamic", [False, True])
def test_subclass_views(self, dynamic):
def _get_views(t): # returns (view: Tensor, expects_raises_false)
# Note that any closed-over SymInts will be symbolicized during fake-ification.
yield t.narrow(dim=-1, start=3, length=8), False
yield t.split(5, -1)[2], False
yield t.split_with_sizes([9, 6], -1)[1], False
yield t.unsqueeze(-1).expand(4, 15, 10), False
yield t.select(-1, 6), False
# https://github.com/pytorch/pytorch/issues/128649
yield t[2:3, 5:9], dynamic
yield t.view(-1, 15), False
def f(x):
return x * 2
compiled_f = torch.compile(
f, backend="aot_eager", fullgraph=True, dynamic=dynamic
)
# Take a view of a subclass to pass as input.
t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15))
for view, expects_raises in _get_views(t):
torch._dynamo.reset()
out_ref = f(view)
if expects_raises:
with self.assertRaises(AssertionError):
out_test = compiled_f(view)
else:
out_test = compiled_f(view)
self.assertEqual(out_ref, out_test)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_mark_static_with_subclass_desugaring(self):
from typing import Any, Callable, Dict, List, Optional
from torch._dynamo.decorators import mark_static_address
from torch._inductor.compile_fx import compile_fx
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.utils import BoxedBool
x_inner = torch.ones(4)
x = TwoTensor(x_inner, x_inner)
mark_static_address(x, guard=False)
def inner_compile(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None,
static_input_idxs: Optional[List[int]] = None,
is_backward: bool = False,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
aot_mode: bool = False,
is_inference: bool = False,
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
user_visible_outputs: Optional[Dict[str, None]] = None,
layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None,
):
self.assertEqual(static_input_idxs, [1, 2])
return gm
compiler = functools.partial(compile_fx, inner_compile=inner_compile)
@torch.compile(backend=compiler)
def fn(t0, t1, t2):
return t0 + t1 + t2 + 2
fn(torch.ones(4), x, torch.ones(4))
instantiate_parametrized_tests(SubclassTests)
class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
return get_jagged_tensor(nested_size, offsets, requires_grad)
def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
# Makes a non-contiguous jagged tensor with N constituent tensors, built from
# per-component starts and lengths over a dense values tensor of inner dim D
max_dim = (starts + lengths).max()
values_tensor = torch.randn(
starts.shape[0],
max_dim.item(),
inner_dim,
requires_grad=requires_grad,
dtype=torch.float64,
)
return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)
def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
_check_recompiles(self, fn, inputs1, inputs2, expected_recompiles)
def test_unary_does_not_recompile(self):
nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None)
self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False)
def test_binary_does_not_recompile(self):
def binary(nt1, nt2):
if nt1.shape == nt2.shape:
return nt1 + nt2
else:
return nt1.sin()
# NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
# This causes a recompile later on when it realizes the batch and last dim
# should not always be equal. To avoid that, we use (3, j0, 5) here.
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None)
nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets)
self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False)
def test_binary_recompiles(self):
def binary(nt1, nt2):
if nt1.shape == nt2.shape:
return nt1 + nt2
else:
return nt1.sin()
# Binary recompiles because singleton ints no longer match
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)
def _validate_compile(self, fn, arg_fn):
def _gen_grad_outputs(out_val):
if isinstance(out_val, (list, tuple)):
return tuple(torch.ones_like(c) for c in out_val)
else:
return (torch.ones_like(out_val),)
with self.branch_nested_state():
from torch.nested._internal.nested_tensor import _tensor_symint_registry
# Validate that compilation does not modify eager state
registry_before = list(_tensor_symint_registry.items())
count_before = torch.nested._internal.nested_tensor._tensor_id_counter
guards_exported = []
guards_failed = []
def append_guard_export(guards):
for g in guards:
if g.code_list is not None:
guards_exported.append(g.code_list[0])
def append_guard_fail(guards):
guards_failed.extend(guards)
compiled = torch._dynamo.optimize(
nopython=True,
backend="aot_eager",
guard_export_fn=append_guard_export,
guard_fail_fn=append_guard_fail,
)(fn)
registry_after = list(_tensor_symint_registry.items())
count_after = torch.nested._internal.nested_tensor._tensor_id_counter
self.assertEqual(registry_before, registry_after)
self.assertEqual(count_before, count_after)
args = arg_fn()
compile_out = compiled(*args)
compile_grads = []
g_args = [arg for arg in args if arg.requires_grad]
if len(g_args) > 0:
compile_grad_outputs = _gen_grad_outputs(compile_out)
compile_grads = torch.autograd.grad(
compile_out, inputs=g_args, grad_outputs=compile_grad_outputs
)
with self.branch_nested_state():
args = arg_fn()
ref_out = fn(*args)
ref_grads = []
g_args = [arg for arg in args if arg.requires_grad]
if len(g_args) > 0:
ref_grad_outputs = _gen_grad_outputs(ref_out)
ref_grads = torch.autograd.grad(
ref_out, inputs=g_args, grad_outputs=ref_grad_outputs
)
# Validate correctness forward
if isinstance(compile_out, (list, tuple)):
# TODO: Fix assertEqual() to support NJTs so this isn't necessary
self.assertEqual(len(compile_out), len(ref_out))
for c, r in zip(compile_out, ref_out):
self.assertEqualIgnoringNestedInts(c, r)
else:
self.assertEqualIgnoringNestedInts(compile_out, ref_out)
# Validate correctness backward
for compile_grad, ref_grad in zip(compile_grads, ref_grads):
self.assertEqualIgnoringNestedInts(compile_grad, ref_grad)
return guards_exported, guards_failed
# Note: [What kind of guards are involved in nested tensor compilation]
#
# Until we implement UnionFind, dynamic shapes guards are not involved;
# we rely only on dynamo's tensor aliasing guards.
#
# This is possible because dynamo is able to generate tensor aliasing guards
# not only for the outer tensor, but also for the inner tensor.
#
# The case where dynamic shapes guards would eventually come into play is
# when my inputs are (1) two non-aliased tensors, but (2) declared as
# equal using a "trust me assert equal" API.
# Note: [Compiling nested tensor global state]
#
# Today there are two pieces of global eager state that NJTs deal with:
# - tensor_id_counter: a global counter that assigns unique ids to tensors
# - tensor_symint_registry: maps tensor to nested int
# - this is used in eager only (we should get rid of this because it is
# not necessary to cache nested int in eager)
# - during tracing, we DO need to cache nested int, but we do so on
# the FakeTensor.
#
# Ideally we would like to satisfy the following:
# - (1) The eager state is not mutated during tracing
# - (2) Running the compiled function should mutate the eager state in the
# same way that running the eager function would
# (a) The global counter should be incremented
# (b) The registry is updated in the same way
#
# Today we can satisfy (1) and (2a) but cannot satisfy (2b)
#
# Today, (1) is satisfied because we maintain a separate counter during
# tracing, and cache nested int on FakeTensor instead of relying on
# tensor_symint_registry.
#
# (2) cannot be completely satisfied because we trace away the
# side-effectful operations (we could fix this by wrapping the
# side-effectful operations in a custom op and threading through effect
# tokens). The current plan is to do that in the UnionFind impl.
#
# Interestingly, despite this, the state is mutated in a way that is somewhat
# close to what we want, e.g. if I construct a nested tensor using an offsets
# tensor in the compiled region and return it, the AOTAutograd runtime wrapper
# must rewrap the inner->inner graph outputs back into the subclass. This
# triggers the eager logic to run, updating the counter and registry.
#
# Notably however, compile differs in two ways from eager:
# (1) The order in which the offsets are assigned ids is different:
# the registry is updated in the order the offsets are returned,
# which is not necessarily the same order as they were constructed.
# (2) If a NestedTensor is not returned, then the AOTAutograd wrapping
# logic will not be triggered.
#
# I claim that correctness is not affected by these differences today.
# e.g. there is never the case where two distinct offsets silently share
# the same id.
#
# (1) is clearly not a problem, and (2) should only be a problem if
# the nested int is returned on its own, without the corresponding NJT
# being returned. This is not a problem in the current implementation
# because returning only a shape is not supported!
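# The helper _validate_compile above checks property (1) directly by
# snapshotting the eager state around compilation, roughly:
#     registry_before = list(_tensor_symint_registry.items())
#     count_before = torch.nested._internal.nested_tensor._tensor_id_counter
#     ...  # compile and run fn
#     assert registry_after == registry_before and count_after == count_before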
# Note: [Creating symbolic nested int]
#
# We must create a symbolic nested int when we construct a nested tensor
# from a tensor. There are two main cases:
#
# 1. The offsets has NOT been used to construct a NJT
# - Create a new plain nested int with current val of fake nt id counter
# - Increment the fake nt id counter
# - Create a new symint with plain nested int as hint
# 2. The offsets HAS been used to construct a NJT
# - Create a new symint with plain nested int as hint
#
# More details on case 2:
# - During fakification of the offsets, we check the eager registry, and
# if the tensor HAS been used to construct a NJT,
# we create a symint, with the existing nested int as hint, and cache
# it on to the FakeTensor.
#
# [ Always use ephemeral source ]
#
# We create the new symint ALWAYS with ephemeral source whether that is
# in case (1) or (2) even though we could've had a proper source for case (2).
# Using a proper source would enable a few more (edge) cases, but since
# we plan to handle things more holistically in the future anyway, we don't
# bother doing so today.
#
# Using an ephemeral source has some consequences. But we are happy if
# - We do not silently miss recompiles, e.g. we guard when necessary.
# We know that this is true, because dynamo guards alone are already
# sufficient.
# - We are not producing errors for the cases we care about
#
# The main case we care about is when we guard that two shapes are equal.
# In this case, the replacements logic would simplify away the ephemeral
# symbol, and there is no error produced.
# The unsupported case is when we guard that two shapes are not equal, in
# which case we will try, and fail, to generate a guard.
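# A small user-level illustration of cases (1) and (2) above (sketch only;
# `values`, `values2`, and `offsets` are placeholder names):
#     offsets = torch.tensor([0, 2, 6, 10])
#     a = torch.nested.nested_tensor_from_jagged(values, offsets)
#         # (1) offsets not seen before: mint a new nested int for the ragged
#         #     dim of `a` and bump the nt id counter
#     b = torch.nested.nested_tensor_from_jagged(values2, offsets)
#         # (2) offsets already registered: the ragged dim of `b` reuses the
#         #     nested int created for `a`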
#
# Case 1: in-graph construction where the offsets are passed as inputs
#
def test_in_graph_construction_from_input(self):
# The offsets is passed as an input
def fn(values, offsets):
return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
self._validate_compile(fn, arg_fn=lambda: (values, offsets))
# Do not specialize on the offsets
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64)
self._validate_compile(fn, arg_fn=lambda: (values, different_offsets))
def test_in_graph_construction_from_input_2(self):
# Construct two NJTs, both are passed as inputs
def fn(values, offsets1, offsets2):
nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1)
nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
return nt2, nt1
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
# 1. Offsets are different
guards_exported, guards_failed = self._validate_compile(
fn, arg_fn=lambda: (values, offsets, offsets2)
)
self.assertEqual(len(guards_failed), 0)
self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported)
# TODO
# 2. Offsets are the same
new_guards_exported, _ = self._validate_compile(
fn, arg_fn=lambda: (values, offsets, offsets)
)
self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed))
self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported)
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
offsets3 = offsets.clone()
self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3))
# Do a binary op
def fn(values, offsets, offsets2):
nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
return nt1 * nt2
self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets))
def test_in_graph_construction_from_input_4(self):
# The offsets tensor is taken from an NJT input
def fn(nt, other_values):
nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets())
return nt + nt2
values = torch.randn(9, 5, requires_grad=True)
other_values = torch.randn(9, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)
def arg_fn(values=values, other_values=other_values, offsets=offsets):
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
return nt, other_values
self._validate_compile(fn, arg_fn=arg_fn)
# Do not specialize on the offsets
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
different_offsets = offsets.clone()
def arg_fn(
values=values, other_values=other_values, offsets=different_offsets
):
nt = torch.nested.nested_tensor_from_jagged(values, different_offsets)
return nt, other_values
self._validate_compile(fn, arg_fn=arg_fn)
def test_in_graph_construction_from_input_5(self):
# Construct from lengths instead of offsets
def fn(values, lengths):
nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
return nt.sin()
values = torch.randn(9, 5, requires_grad=True)
lengths = torch.tensor([2, 4, 3])
self._validate_compile(fn, arg_fn=lambda: (values, lengths))
#
# Case 2: in-graph construction where offsets are graph intermediates
#
def test_in_graph_construction_from_intermediate(self):
# offsets is an intermediate computed from lengths
def fn(values, lengths):
offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
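# e.g. lengths = [2, 4, 3] gives offsets = [0, 2, 6, 9]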
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
return (nt * nt2).sin()
values = torch.randn(9, 5, requires_grad=True)
lengths = torch.tensor([2, 4, 3])
self._validate_compile(fn, arg_fn=lambda: (values, lengths))
# Do not specialize on the lengths
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
different_lengths = lengths.clone()
self._validate_compile(fn, arg_fn=lambda: (values, different_lengths))
def test_in_graph_construction_from_intermediate_2(self):
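# The offsets tensor is cloned inside the graph, so the NJT is constructed
# from an intermediate rather than directly from a graph input.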
def fn(values, offsets):
return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
self._validate_compile(fn, arg_fn=lambda: (values, offsets))
def test_in_graph_construction_from_intermediate_3(self):
# Note that due to CSE, clone is not necessarily called twice!
def fn(values, offsets):
nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone())
return nt2, nt1
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
self._validate_compile(fn, arg_fn=lambda: (values, offsets))
def test_in_graph_construction_from_intermediate_4(self):
# Shared intermediate (should be same as case #1)
def fn(values):
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
values2 = torch.ones_like(values)
nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets)
return nt * nt2
values = torch.randn(10, 5).requires_grad_(True)
self._validate_compile(fn, arg_fn=lambda: (values,))
# AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
@unittest.expectedFailure
def test_in_graph_construction_from_intermediate_5(self):
# non-shared intermediate
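# (the two NJTs are built from different offsets tensors, so the shape
# comparison below guards on two distinct nested ints)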
def fn(values):
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
values2 = torch.ones_like(values)
nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone())
if nt2.shape[1] != nt.shape[1]:
return nt * 2
else:
return nt * 3
values = torch.randn(10, 5).requires_grad_(True)
self._validate_compile(fn, arg_fn=lambda: (values,))
#
# Case 3: in-graph construction where the offsets are both passed directly
# as graph inputs and reachable as the offsets of an NJT input.
#
def test_in_graph_construction_mixed(self):
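# The same offsets tensor is both a direct graph input and the offsets of
# the input NJT.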
def fn(nt, values, offsets):
nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
return nt * nt2
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
def arg_fn(values=values, offsets=offsets):
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
return nt, values, offsets
self._validate_compile(fn, arg_fn)
# See Note: [Creating symbolic nested int]
# AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
@unittest.expectedFailure
def test_in_graph_construction_mixed_2(self):
def fn(nt, values, offsets, nt2):
# Intermediate offsets has ephemeral source
intermediate_nt = torch.nested.nested_tensor_from_jagged(
values, offsets.clone()
)
# This creates a dynamic shapes neq guard
if nt2.shape[1] != intermediate_nt.shape[1]:
# We should always go here.
nt = nt * 2
return nt
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
def arg_fn(values=values, offsets=offsets, offsets2=offsets2):
# Values is shared, but it shouldn't matter
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2)
return nt, values, offsets, nt2
self._validate_compile(fn, arg_fn)
def test_in_graph_construction_mixed_3(self):
# More involved mixed case
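# (the input NJT's offsets are reused to construct two more NJTs in-graph)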
def fn(nt, values, offsets):
nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets)
return nt1 + nt2 + nt
values = torch.randn(9, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)
def arg_fn(values=values, offsets=offsets):
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
return nt, values, offsets
self._validate_compile(fn, arg_fn)
def test_return_shape(self):
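# Returning an NJT's shape (which contains a nested int) from the compiled
# function should work under fullgraph=True.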
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
def fn(nt):
return (nt * 2).shape
compiled = torch.compile(fn, fullgraph=True, backend="aot_eager")
compiled(nt)
# TODO: cannot parametrize this test class with device for some reason
def _test_autograd(self, backend):
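# Compile a simple binary op over two NJTs that share offsets and check
# that the gradients of the constituent tensors match eager.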
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
# TODO: Switch to public API when it exists
nt2, _ = jagged_from_list([a, b, c], nt.offsets())
def fn1(nt1, nt2):
return (nt1 + nt2).sin().cos()
compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True)
out = compiled_f(nt, nt2)
out_buffer = out.values()
ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))
out_ref = fn1(nt, nt2)
out_buffer_ref = out_ref.values()
ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c))
self.assertTrue(torch.allclose(ga, ga_ref))
self.assertTrue(torch.allclose(gb, gb_ref))
self.assertTrue(torch.allclose(gc, gc_ref))
def test_basic_autograd(self):
self._test_autograd("aot_eager")
@requires_cuda
def test_basic_autograd_inductor(self):
self._test_autograd("inductor")
def test_subclass_with_mutation_in_graph(self):
# This graph contains an in-graph mutation, i.e. a mutation that is allowed
# to remain in the graph. Normally that is fine, but it is not allowed when
# the graph handles subclasses at all.
# Whether or not the mutation stays in the graph changes the number of
# outputs of the forward graph. Previously, a bug in this handling meant
# that the expected and actual number of outputs from the joint graph
# sometimes did not match, causing assertion failures.
def fn(x, y):
z = x.sin()
y.sin_()
return z.cos(), y.cos()
fn_c = torch.compile(fn, backend="inductor")
values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)]
values_copy = [x.detach().clone().requires_grad_(True) for x in values]
nt, offsets = jagged_from_list(values, None)
nt_copy, offsets = jagged_from_list(values_copy, offsets)
y = torch.rand((4, 8))
y_copy = y.clone()
ret = fn_c(nt, y)[0]
ref = fn(nt_copy, y_copy)[0]
self.assertEqual(ret.values(), ref.values())
ret.values().sum().backward()
ref.values().sum().backward()
for ref_v, res_v in zip(values_copy, values):
self.assertEqual(ref_v.grad, res_v.grad)
def test_unbind(self):
# NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
# This causes a recompile later on when it realizes the batch and last dim
# should not always be equal. To avoid that, we use (3, j0, 5) here.
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None)
nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None)
def fn(x):
return x.unbind()
compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True)
out = compiled_f(nt)
out_ref = fn(nt)
# correctness
self.assertEqual(len(out), len(out_ref))
for x, x_ref in zip(out, out_ref):
self.assertTrue(torch.allclose(x, x_ref))
# We specialize on the length of offsets, i.e. (1) we recompile if the
# length of the offsets is different, and (2) we don't recompile if the
# length of the offsets is the same, even if the sizes of the constituent
# tensors are different.
self._check_recompiles(fn, (nt,), (nt2,), False)
self._check_recompiles(fn, (nt,), (nt3,), True)
def test_inline_nested_tensor_from_jagged(self):
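# Constructing an NJT from an existing NJT's values/offsets should trace
# through nested_tensor_from_jagged without graph breaks (fullgraph=True).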
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
def fn(x):
return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())
torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)
# The test here: nn.Parameters that are secretly subclasses
# have a metaclass that overrides __instancecheck__,
# which dynamo needs to respect when it inlines the if statement.
def test_param_subclass_isinstance_input(self):
x_inner = torch.randn(16, 16, requires_grad=True)
x = torch.nn.Parameter(TwoTensor(x_inner, x_inner))
m = torch.nn.Linear(16, 16)
m.weight = x
def fn():
if isinstance(m.weight, torch.nn.Parameter):
return m.weight + 1
else:
return m.weight + 2
out_ref = fn()
out_test = torch.compile(fn, backend="aot_eager")()
self.assertEqual(out_ref, out_test)
def _input_view_test(self, nt_view_name):
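# Compile a pointwise op over the given view input, check metadata and
# values against eager, then verify exactly which shape guards are added.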
nt_view = VIEW_TEST_CASES[nt_view_name]()
def fn(x):
return x.sin()
out_ref = fn(nt_view)
torch._dynamo.reset()
compile_fn = torch.compile(
fn, fullgraph=True, backend="aot_eager", dynamic=True
)
out = compile_fn(nt_view)
# Check metadata and values are correct
self.assertTrue(out.size() == out_ref.size())
self.assertTrue(out.stride() == out_ref.stride())
if out.is_nested:
self.assertTrue(torch.allclose(out.values(), out_ref.values()))
else:
self.assertTrue(torch.allclose(out, out_ref))
# Check that no upper/lower bound guards are incurred
def backend(gm, args):
context = torch._guards.TracingContext.get()
guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
# varies based on the type of view
guard_str = "\n".join(guards)
if nt_view_name == "subclass_dense":
self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
elif nt_view_name == "dense_subclass_dense_subclass":
self.assertExpectedInline(
guard_str,
"""\
Eq(s5 - 1, s2)
Eq(s12 - 1, s7)
Eq(s11, s9)""",
)
elif nt_view_name.startswith("base_is_nt_True"):
self.assertExpectedInline(
guard_str,
"""Eq(s3 - 1, s0)""",
)
else:
self.assertExpectedInline(
guard_str,
"""\
Eq(s4 - 1, s1)
Eq(s13 - 1, s8)
Eq(s12, s10)""",
)
return gm
torch._dynamo.reset()
compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
out = compile_fn(nt_view)
@parametrize(
"nt_view_name",
[k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
)
def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
self._input_view_test(nt_view_name)
def test_subclass_gives_static_shapes_when_dynamic_false(self):
def check_graph(gm, *args):
first_node_example_val = next(iter(gm.graph.nodes)).meta["example_value"]
# We compiled with dynamic=False, expect no SymInt sizes on our placeholders
self.assertTrue(
all(isinstance(x, int) for x in first_node_example_val.shape)
)
return gm
@torch.compile(backend=check_graph, dynamic=False)
def f(x):
return x + 1
x_inner = torch.ones(4)
x = TwoTensor(x_inner, x_inner)
x_view = x.view(2, 2)
out = f(x_view)
# NJT1 -> Dense -> NJT2 -> Dense view
# During view replay, the Dense -> NJT2 part will construct an intermediate,
# symbolically-sized NJT that is immediately deconstructed to return the final dense
# view. To construct this intermediate properly, we need the associated nested int
# to be symbolic. This view is expected to fail compilation until symbolic nested ints
# are cached onto fake offsets to solve this problem.
@unittest.expectedFailure
def test_subclass_dense_subclass_dense_view(self):
self._input_view_test("subclass_dense_subclass_dense")
instantiate_parametrized_tests(TestNestedTensor)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()