# Owner(s): ["module: dynamo"]
import functools
import itertools
import unittest
from functools import partial

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm
from torch._functorch._aot_autograd.utils import make_boxed_compiler
from torch._functorch.compilers import min_cut_rematerialization_partition
from torch._higher_order_ops.wrap import wrap
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)
from torch.nested._internal.nested_tensor import (
    jagged_from_list,
    jagged_from_tensor_and_lengths,
    nested_view_from_values_offsets,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    NestedTensorTestCase,
    parametrize,
    subtest,
)
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import return_and_correct_aliasing

def nontraceable_subclass(c):
    return torch._dynamo.config.patch("nontraceable_tensor_subclasses", {c})


def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
    actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2)
    self.assertEqual(actual_recompiles, expected_recompiles)


def get_jagged_tensor(nested_size, offsets, requires_grad=True):
    # Makes a jagged tensor with N constituent tensors with size
    # as specified ((S0, S1, S2), D)
    D = nested_size[1]
    out = []
    for s in nested_size[0]:
        out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64))
    return jagged_from_list(out, offsets)

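
# Illustrative example (not part of the test logic, values chosen here for
# illustration): with nested_size=((2, 3, 4), 5) and offsets=None,
# get_jagged_tensor builds the jagged tensor from three float64 components of
# shape (2, 5), (3, 5) and (4, 5), and returns the (nested tensor, offsets)
# pair produced by jagged_from_list.
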
def get_view_test_cases():
    # Test all cases with both an NT base and a dense base
    # Subclass -> Subclass
    # Dense -> Subclass

    # NB: Don't close over loop variables, they will not get copied into the
    # closure
    #
    # NB: These return functions so we don't generate tensors during test
    # collection time

    def mk_basic(base_is_nt):
        # There are three cases to consider here based on the logic in
        # meta_utils.py
        #
        # (1) basic case:
        # view is not a leaf and has the same requires grad as its base
        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
        x = x.clone() if base_is_nt else x
        assert not x.is_leaf
        return x.unsqueeze(-1)

    def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2):
        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1)
        x = x.clone() if base_is_nt else x
        with torch.no_grad():
            x_view = x.unsqueeze(-1)
            # The issue is this doesn't quite work
            x_view.requires_grad_(requires_grad_2)

        return x_view

    def mk_obscure(base_is_nt):
        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
        x = x.clone() if base_is_nt else x
        # intermediate leaf view
        with torch.no_grad():
            x_view = x.unsqueeze(-1)
        x_view.requires_grad_(True)
        x_view_view = x_view.unsqueeze(-1)
        return x_view_view

    for base_is_nt in [False, True]:
        prefix = f"base_is_nt_{base_is_nt}"

        yield partial(mk_basic, base_is_nt), f"{prefix}_basic"

        # (2) leaf view case:
        # the view has to be a leaf (w/ requires_grad True or requires_grad False)
        # base w/ requires_grad True or requires_grad False
        for requires_grad_1, requires_grad_2 in itertools.product(
            [True, False], repeat=2
        ):
            yield (
                partial(mk_leaf, base_is_nt, requires_grad_1, requires_grad_2),
                f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}",
            )

        # (3) obscure case:
        # view is not a leaf (implies requires_grad True)
        # base w/ requires_grad False
        yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"

    # Subclass -> Dense
    yield (
        lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone(),
        "subclass_dense",
    )

    # Dense -> Subclass -> Dense -> Subclass
    def mk_dense_subclass_dense_subclass():
        values = torch.randn(10, 5)
        offsets = torch.tensor([0, 3, 6, 10])
        return nested_view_from_values_offsets(
            nested_view_from_values_offsets(values, offsets).values(), offsets
        )

    yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass"

    def mk_subclass_dense_subclass_dense():
        x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
        offsets2 = x.offsets().detach().clone()
        nested_view_from_values_offsets(x.values(), offsets2).values()

    yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense"


VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()}


compile_full_eager = torch.compile(backend="eager", fullgraph=True)

class BaseTorchFunction(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, types, args, kwargs)


class MockSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, types, args, kwargs)


class AttrSubclass(torch.Tensor):
    x: int = 10
    size: int = 10

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, types, args, kwargs)


class DummyNDim(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func == torch.Tensor.ndim.__get__:
            return 10

        return super().__torch_function__(func, types, args, kwargs)


class WrapperSubclass:
    def __init__(self, tensor):
        self.tensor = tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args)
        kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs)

        return func(*args, **kwargs)


class SigmoidToExpSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func == torch.Tensor.sigmoid:
            return super().__torch_function__(torch.Tensor.exp, types, args, kwargs)

        return super().__torch_function__(func, types, args, kwargs)

# Wrapper subclass with two inner tensors: data and scale
# data has same shape as outer, and scale has single dim size
class ScaledTensor(torch.Tensor):
    def __new__(
        cls,
        data: torch.Tensor,
        scale: torch.Tensor,
        *,
        constant: int = 0,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=data.dtype,
            layout=data.layout,
            requires_grad=data.requires_grad,
            device=data.device,
        )

    def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0):
        self._data = data
        self._scale = scale
        self._constant = constant

    def __tensor_flatten__(self):
        ctx = {"_constant": self._constant}
        return ["_data", "_scale"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        assert len(inner_tensors) == 2
        return ScaledTensor(
            inner_tensors["_data"],
            inner_tensors["_scale"],
            constant=metadata["_constant"],
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        scaled_tensor = args[0]
        out = func(scaled_tensor._data, *args[1:], **kwargs)
        return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant)

    def __repr__(self):
        return f"{self._data.__repr__()}\n{self._scale.__repr__()}"

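
# Minimal illustrative sketch (not exercised directly by the tests): the
# __tensor_flatten__ / __tensor_unflatten__ pair above is what lets tracing
# round-trip a wrapper subclass through its inner tensors plus metadata:
#
#   st = ScaledTensor(torch.randn(2, 4), torch.randn(6), constant=1)
#   attrs, ctx = st.__tensor_flatten__()   # (["_data", "_scale"], {"_constant": 1})
#   inner = {name: getattr(st, name) for name in attrs}
#   st2 = ScaledTensor.__tensor_unflatten__(inner, ctx, st.size(), st.stride())
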
class OptionalScaledTensor(torch.Tensor):
    def __new__(
        cls,
        data,
        scale,
        *,
        constant: int = 0,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=data.dtype,
            layout=data.layout,
            requires_grad=data.requires_grad,
            device=data.device,
        )

    def __init__(self, data: torch.Tensor, scale, constant: int = 0):
        self._data = data
        self._scale = scale
        self._constant = constant

    def __tensor_flatten__(self):
        ctx = {"_constant": self._constant}
        if self._scale is not None:
            return ["_data", "_scale"], ctx
        else:
            return ["_data"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return OptionalScaledTensor(
            inner_tensors["_data"],
            inner_tensors["_scale"] if "_scale" in inner_tensors else None,
            constant=metadata["_constant"],
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        scaled_tensor = args[0]
        out = func(scaled_tensor._data, *args[1:], **kwargs)
        if scaled_tensor._scale is not None:
            out = out * scaled_tensor._scale
        return OptionalScaledTensor(
            out, scaled_tensor._scale, constant=scaled_tensor._constant
        )

    def __repr__(self):
        return (
            f"OptionalScaledTensor({self._data.__repr__()}\n{self._scale.__repr__()})"
        )

class CtxSubclassTensor(torch.Tensor):
    """
    Class used to verify guarding on the subclass metadata
    """

    @staticmethod
    def __new__(cls, a, constant):
        shape = a.shape
        kwargs = {}
        kwargs["strides"] = a.stride()
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
        return out

    def __init__(self, a, constant):
        self.a = a
        self.constant = constant

    def __repr__(self):
        a_repr = repr(self.a)
        return f"CtxSubclassTensor({a_repr})"

    def __tensor_flatten__(self):
        return ["a"], (self.constant,)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, sizes, strides):
        constant = meta[0]
        a = inner_tensors["a"]
        return CtxSubclassTensor(a, constant)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        biggest_constant = max(
            [
                x.constant
                for x in pytree.tree_flatten(args)[0]
                if isinstance(x, CtxSubclassTensor)
            ]
        )
        args_a = pytree.tree_map(
            lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, args
        )
        kwargs_a = pytree.tree_map(
            lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, kwargs
        )
        out_a = func(*args_a, **kwargs_a)
        out = pytree.tree_map(
            lambda x: (
                CtxSubclassTensor(x, biggest_constant)
                if isinstance(x, torch.Tensor)
                else x
            ),
            out_a,
        )

        if func == torch.ops.aten.mul.Tensor:
            out = out + out.constant

        return return_and_correct_aliasing(func, args, kwargs, out)

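
# Rough illustration of the behavior above (a sketch, not asserted verbatim by
# the tests): the `constant` metadata is re-attached to every result in
# __torch_dispatch__, and aten.mul additionally adds it to the output, so
#
#   x = CtxSubclassTensor(torch.ones(2), 3)
#   y = x * 2   # inner values end up roughly 1 * 2 + 3 == 5, y.constant == 3
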
def func(a):
    return a.sin()


class EagerRecordGraphAndInputs:
    def __init__(self) -> None:
        self.graphs = []
        self.example_inputs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs.append(gm)
        self.example_inputs.append(example_inputs)
        return gm

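
# Typical usage in the tests below: an instance acts as a Dynamo backend
# (optionally wrapped in CompileCounterWithBackend) so the captured FX graphs
# and example inputs can be inspected after compilation; `f` here stands for
# any function under test:
#
#   backend = EagerRecordGraphAndInputs()
#   cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
#   compiled_f = torch.compile(f, backend=cnt, fullgraph=True)
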
GLOBAL_TEST_SUBCLASSES = {
    MockSubclass,
    DummyNDim,
    SigmoidToExpSubclass,
    BaseTorchFunction,
}


# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
    compile_count = [0]

    def counter(gm, example_inputs):
        compile_count[0] += 1
        return gm

    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
    compiled_f(*inputs1)
    compiled_f(*inputs2)
    return compile_count[0] > 1

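
# Illustrative expectation (a sketch mirroring how the helper is used below):
# with dynamic=True a pure size change usually does not force a recompile,
# while with dynamic=False it typically does, e.g.:
#
#   _recompiles_for_inputs(func, (torch.randn(3),), (torch.randn(4),), dynamic=True)   # expected False
#   _recompiles_for_inputs(func, (torch.randn(3),), (torch.randn(4),), dynamic=False)  # expected True
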
class SubclassTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()

    def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
        _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles)

def test_no_call_to_new(self):
|
|
class BadNewTorchFunction(torch.Tensor):
|
|
def __new__(cls, *args, **kwargs):
|
|
raise RuntimeError("Oops!")
|
|
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return torch.add(x, 1)
|
|
|
|
input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)
|
|
|
|
res = fn(input)
|
|
self.assertIsInstance(res, BadNewTorchFunction)
|
|
|
|
def test_no_torch_function_recompiles(self):
|
|
class NJT:
|
|
def __repr__(self):
|
|
return f"NJT(shape={self.shape})"
|
|
|
|
def __init__(self, values, offsets):
|
|
self._values = values
|
|
self._offsets = offsets
|
|
|
|
def sin(self):
|
|
return torch.sin(self)
|
|
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
if func == torch.sin:
|
|
self = args[0]
|
|
return NJT(func(self._values), self._offsets)
|
|
raise AssertionError("should not get here")
|
|
|
|
values1 = torch.randn(10, 3, 4, requires_grad=True)
|
|
values2 = torch.randn(10, 3, 4, requires_grad=True)
|
|
offsets = torch.tensor([0, 3, 10])
|
|
njt1 = NJT(values1, offsets)
|
|
njt2 = NJT(values2, offsets)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def f(x):
|
|
return torch.sin(x)
|
|
|
|
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
|
f(njt1)
|
|
f(njt2)
|
|
|
|
def test_base_torch_function_tracing(self):
|
|
def fn(x):
|
|
return torch.add(x, 1)
|
|
|
|
input = torch.ones(2, 2).as_subclass(BaseTorchFunction)
|
|
out = fn(input)
|
|
out_opt = compile_full_eager(fn)(input)
|
|
self.assertIsInstance(out, BaseTorchFunction)
|
|
self.assertEqual(out, out_opt)
|
|
|
|
def test_torch_function_state_graph_break(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
with torch._C.DisableTorchFunctionSubclass():
|
|
torch._dynamo.graph_break()
|
|
return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)
|
|
|
|
input = torch.ones(2, 2)
|
|
res, _ = fn(input)
|
|
self.assertFalse(res)
|
|
|
|
def test_disable_all_torch_function(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
with torch._C.DisableTorchFunction():
|
|
torch._dynamo.graph_break()
|
|
return (
|
|
torch._C._is_torch_function_enabled(),
|
|
torch._C._is_torch_function_all_disabled(),
|
|
torch.add(x, 1.0),
|
|
)
|
|
|
|
input = torch.ones(2, 2)
|
|
res1, res2, _ = fn(input)
|
|
self.assertFalse(res1)
|
|
self.assertTrue(res2)
|
|
|
|
def test_disable_all_torch_function_restore_values(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
with torch._C.DisableTorchFunction():
|
|
x = torch._C._is_torch_function_all_disabled()
|
|
|
|
return (
|
|
x,
|
|
torch._C._is_torch_function_all_disabled(),
|
|
torch.add(x, 1.0),
|
|
)
|
|
|
|
input = torch.ones(2, 2)
|
|
res1, res2, _ = fn(input)
|
|
self.assertTrue(res1)
|
|
self.assertFalse(res2)
|
|
|
|
def test_disable_all_torch_function_restore_values_graph_break(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
with torch._C.DisableTorchFunction():
|
|
torch._dynamo.graph_break()
|
|
x = torch._C._is_torch_function_all_disabled()
|
|
|
|
return (
|
|
x,
|
|
torch._C._is_torch_function_all_disabled(),
|
|
torch.add(x, 1.0),
|
|
)
|
|
|
|
input = torch.ones(2, 2)
|
|
res1, res2, _ = fn(input)
|
|
self.assertTrue(res1)
|
|
self.assertFalse(res2)
|
|
|
|
def test_torch_function_state_nested(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
with torch._C.DisableTorchFunctionSubclass():
|
|
with torch._C.DisableTorchFunctionSubclass():
|
|
x = x + 1
|
|
# Should reset to the outer state (disabled) after exiting ctx manager
|
|
return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)
|
|
|
|
input = torch.ones(2, 2)
|
|
res, _ = fn(input)
|
|
self.assertFalse(res)
|
|
|
|
def test_torch_function_state_tracing(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
with torch._C.DisableTorchFunctionSubclass():
|
|
torch.add(x, 1.0)
|
|
|
|
input = torch.ones(2, 2)
|
|
|
|
fn(input)
|
|
|
|
def test_torch_function_state_guards(self):
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def fn(x):
|
|
torch.add(x, 1.0)
|
|
|
|
input = torch.ones(2, 2)
|
|
|
|
with torch._C.DisableTorchFunctionSubclass():
|
|
fn(input)
|
|
|
|
fn(input)
|
|
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_return_subclass(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return MockSubclass(torch.add(x, 1.0)) * 2
|
|
|
|
input = torch.ones(2, 2)
|
|
|
|
res = fn(input)
|
|
self.assertIsInstance(res, MockSubclass)
|
|
|
|
def test_return_as_subclass(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return torch.add(x, 1.0).as_subclass(MockSubclass) * 2
|
|
|
|
input = torch.ones(2, 2)
|
|
|
|
res = fn(input)
|
|
self.assertIsInstance(res, MockSubclass)
|
|
|
|
def test_return_local_subclass(self):
|
|
class LocalSubclass(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return LocalSubclass(torch.add(x, 1.0)) * 2
|
|
|
|
input = torch.ones(2, 2)
|
|
|
|
res = fn(input)
|
|
self.assertIsInstance(res, LocalSubclass)
|
|
|
|
def test_torch_function_list_args(self):
|
|
HANDLED_FUNCTIONS = {}
|
|
|
|
class MyClass:
|
|
def __init__(self, foo):
|
|
self.foo = foo
|
|
|
|
@classmethod
|
|
def __torch_function__(
|
|
cls,
|
|
func,
|
|
types,
|
|
args=(),
|
|
kwargs=None,
|
|
):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
if func not in HANDLED_FUNCTIONS or not all( # noqa: C419
|
|
[ # noqa: C419
|
|
issubclass(t, (torch.Tensor, MyClass)) for t in types
|
|
]
|
|
):
|
|
return NotImplemented
|
|
return HANDLED_FUNCTIONS[func](*args, **kwargs)
|
|
|
|
def _stack(input, dim=0, *, out=None):
|
|
return MyClass(sum([x.foo for x in input]))
|
|
|
|
HANDLED_FUNCTIONS[torch.stack] = _stack
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(v0, v1):
|
|
return torch.stack([v0, v1])
|
|
|
|
ret = fn(MyClass(1), MyClass(1))
|
|
self.assertEqual(ret.foo, 2)
|
|
|
|
@parametrize(
|
|
"comparison",
|
|
[
|
|
subtest(isinstance, "isinstance"),
|
|
subtest(lambda instance, type_: type(instance) == type_, "equality"),
|
|
subtest(lambda instance, type_: type(instance) is type_, "identity"),
|
|
],
|
|
)
|
|
@parametrize(
|
|
"input_type",
|
|
[
|
|
subtest(torch.Tensor, "tensor"),
|
|
subtest(DummyNDim, "subclass"),
|
|
],
|
|
)
|
|
def test_type_check(self, comparison, input_type):
|
|
def fn(x):
|
|
if comparison(x, DummyNDim):
|
|
return torch.ones(1, 1)
|
|
else:
|
|
return torch.zeros(2, 2)
|
|
|
|
input = torch.ones(2, 2).as_subclass(input_type)
|
|
exp_res = fn(input)
|
|
act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
|
|
self.assertEqual(exp_res, act_res)
|
|
|
|
def test_torch_function_call_on_method(self):
|
|
x = torch.ones(2, 2)
|
|
y = torch.ones(2, 2)
|
|
z = torch.ones(2, 2)
|
|
wrapped = x.as_subclass(SigmoidToExpSubclass)
|
|
wrapped2 = y.as_subclass(SigmoidToExpSubclass)
|
|
|
|
def fn(w):
|
|
return w.exp()
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(wrapped)
|
|
res_act = fn_opt(wrapped2)
|
|
res_exp2 = z.exp()
|
|
|
|
self.assertEqual(res_exp, res_act)
|
|
self.assertEqual(res_exp, res_exp2)
|
|
|
|
def test_torch_function_call_on_method_arg(self):
|
|
class LocalSubclass(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if func == torch._C.TensorBase.add_:
|
|
func = torch._C.TensorBase.sub_
|
|
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
def sigmoid(self):
|
|
return None
|
|
|
|
x = torch.ones(2, 2)
|
|
y = torch.ones(2, 2)
|
|
z = torch.ones(2, 2)
|
|
wrapped = y.as_subclass(LocalSubclass)
|
|
wrapped2 = z.as_subclass(LocalSubclass)
|
|
|
|
def fn(a, w):
|
|
a.add_(w)
|
|
return a
|
|
|
|
fn_opt = torch.compile(fn)
|
|
|
|
res_exp = fn(x, wrapped)
|
|
res_act = fn_opt(y, wrapped2)
|
|
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_user_overridden_method_unsupported(self):
|
|
class LocalSubclass(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
def sigmoid(self):
|
|
return None
|
|
|
|
def fn(x):
|
|
x.sigmoid()
|
|
|
|
x = torch.ones(2, 2).as_subclass(LocalSubclass)
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(x)
|
|
res_act = fn_opt(x)
|
|
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_user_overridden_attr_unsupported(self):
|
|
class LocalSubclass(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
ndim = 10
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return x.ndim
|
|
|
|
msg = "`torch.compile` only support tracing certain types of overridden tensor subclass attributes"
|
|
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
|
|
x = torch.ones(2, 2).as_subclass(LocalSubclass)
|
|
fn(x)
|
|
|
|
def test_user_overridden_property_unsupported(self):
|
|
class LocalSubclass(torch.Tensor):
|
|
def __init__(self, *args, **kwargs) -> None:
|
|
self._ndim = 10
|
|
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
@property
|
|
def ndim(self):
|
|
return self._ndim
|
|
|
|
@ndim.setter
|
|
def ndim(self, value):
|
|
self._ndim = value
|
|
|
|
def fn(x):
|
|
return x + x.ndim
|
|
|
|
x = LocalSubclass(torch.ones(2, 2))
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(x)
|
|
res_act = fn_opt(x)
|
|
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_overridden_method_guarding(self):
|
|
class LocalSubclass(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return x.sigmoid()
|
|
|
|
with torch._dynamo.config.patch(error_on_recompile=True):
|
|
x = torch.ones(2, 2).as_subclass(LocalSubclass)
|
|
fn(x)
|
|
fn(x)
|
|
x = torch.ones(2, 2).as_subclass(LocalSubclass)
|
|
fn(x)
|
|
|
|
with self.assertRaisesRegex(TypeError, "'bool' object is not callable"):
|
|
LocalSubclass.sigmoid = False
|
|
fn(x)
|
|
|
|
def test_torch_function_call_on_attr(self):
|
|
x = torch.ones(2, 2)
|
|
wrapped = x.as_subclass(DummyNDim)
|
|
|
|
def fn(w):
|
|
return w.ndim + torch.ones(2)
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(wrapped)
|
|
res_act = fn_opt(wrapped)
|
|
|
|
self.assertEqual(res_exp, res_act)
|
|
self.assertEqual(res_exp, torch.ones(2) + 10)
|
|
|
|
def test_torch_function_wrapper_class(self):
|
|
x = torch.ones(2, 2)
|
|
wrapped = WrapperSubclass(x)
|
|
|
|
def fn(w):
|
|
return torch.add(w, 1.0)
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(wrapped)
|
|
res_act = fn_opt(wrapped)
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_no_torch_function_on_size_bytecode(self):
|
|
class TestTensor(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
with torch._C.DisableTorchFunctionSubclass():
|
|
out = func(*args, **kwargs)
|
|
|
|
if func == torch.clone:
|
|
return out * 2
|
|
else:
|
|
return out
|
|
|
|
def fn(x):
|
|
return torch.clone(x)
|
|
|
|
inp = torch.ones(4, 4)
|
|
x = inp.as_subclass(TestTensor)
|
|
torch._dynamo.mark_dynamic(x, 0)
|
|
compiled_fn = torch.compile(fn, fullgraph=True)
|
|
out = compiled_fn(x)
|
|
self.assertEqual(out, torch.ones(4, 4) * 2)
|
|
|
|
def test_torch_function_wrapper_class_with_kwargs(self):
|
|
x = torch.ones(2, 2)
|
|
wrapped = WrapperSubclass(x)
|
|
|
|
def fn(w):
|
|
return torch.add(w, 1.0, alpha=2.0)
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(wrapped)
|
|
res_act = fn_opt(wrapped)
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_tensor_subclass_with_non_classmethod_torch_function(self):
|
|
class MySubclass(torch.Tensor):
|
|
def __torch_function__(self, func, types, args, kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
with torch._C.DisableTorchFunctionSubclass():
|
|
return func(*args, **kwargs)
|
|
|
|
def fn(x):
|
|
return x + 1
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
x = torch.randn(2, 2).as_subclass(MySubclass)
|
|
res_exp = fn(x)
|
|
res_act = fn_opt(x)
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_tensor_subclass_custom_attr(self):
|
|
class AttrSubclass(torch.Tensor):
|
|
x: int = 10
|
|
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return x.x + torch.ones(2, 2)
|
|
|
|
input = torch.ones(2, 2).as_subclass(AttrSubclass)
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(input)
|
|
res_act = fn_opt(input)
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_make_subclass(self):
|
|
# Make sure `torch.Tensor._make_subclass` is traceable, and Dynamo
|
|
# models its aliasing relationships correctly.
|
|
class MySubclass(torch.Tensor):
|
|
pass
|
|
|
|
def fn(x):
|
|
# Downcast then upcast
|
|
y = torch.Tensor._make_subclass(MySubclass, x)
|
|
z = torch.Tensor._make_subclass(torch.Tensor, x)
|
|
# Now `x, y, z` should have the same underlying data.
|
|
x += 1
|
|
y += 2
|
|
z += 3
|
|
res = x * y + z
|
|
return res
|
|
|
|
x0 = torch.randn(2, 2)
|
|
x1 = x0.clone()
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(x0)
|
|
res_act = fn_opt(x1)
|
|
self.assertEqual(res_exp, res_act)
|
|
self.assertEqual(x0, x1)
|
|
|
|
def test_subclass_override_shape_and_to(self):
|
|
# This is a slight variation of
|
|
# https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435
|
|
class MySubclass(torch.Tensor):
|
|
def to(self, *args, **kwargs):
|
|
new = super().to(*args, **kwargs)
|
|
new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
|
|
return new
|
|
|
|
@property
|
|
def shape(self):
|
|
if not hasattr(self, "tensor_shape"):
|
|
self.tensor_shape = self.size()
|
|
return self.tensor_shape
|
|
|
|
def fn(x):
|
|
x_shape = x.shape
|
|
y = x.to("cpu")
|
|
return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape
|
|
|
|
x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
|
|
x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass))
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(x0)
|
|
res_act = fn_opt(x1)
|
|
self.assertEqual(res_exp, res_act)
|
|
self.assertEqual(x0, x1)
|
|
self.assertEqual(x0.tensor_shape, x1.tensor_shape)
|
|
|
|
def test_subclass_dont_invoke_torch_function_on_overridden_method(self):
|
|
# We shouldn't fire `__torch_function__` for overridden tensor methods.
|
|
class MySubclass(torch.Tensor):
|
|
def to(self, device):
|
|
return self * len(device)
|
|
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if func is torch.Tensor.to:
|
|
torch._dynamo.graph_break()
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
def fn(x):
|
|
return x.to("cpu")
|
|
|
|
x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(x)
|
|
res_act = fn_opt(x)
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_subclass_dont_invoke_torch_function_on_overridden_attr(self):
|
|
from types import MethodWrapperType
|
|
|
|
# We shouldn't fire `__torch_function__` for overridden tensor attrs.
|
|
class MySubclass(torch.Tensor):
|
|
def ndim(self):
|
|
return 42
|
|
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if type(func) is MethodWrapperType and func.__name__ == "ndim":
|
|
torch._dynamo.graph_break()
|
|
return super().__torch_function__(func, types, args, kwargs)
|
|
|
|
def fn(x):
|
|
return x + x.ndim()
|
|
|
|
x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
res_exp = fn(x)
|
|
res_act = fn_opt(x)
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_parameter_subclass_with_old_torch_function(self):
|
|
class MySubclass(torch.nn.Parameter):
|
|
pass
|
|
|
|
def fn(x):
|
|
x = x.t()
|
|
x = x.T
|
|
return x + 1
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
x = torch.randn(2, 2).as_subclass(MySubclass)
|
|
res_exp = fn(x)
|
|
res_act = fn_opt(x)
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_subclass_with_disabled_torch_function(self):
|
|
class MySubclass(torch.Tensor):
|
|
__torch_function__ = torch._C._disabled_torch_function_impl
|
|
|
|
def fn(x):
|
|
x = x.t()
|
|
x = x.T
|
|
return x + 1
|
|
|
|
fn_opt = compile_full_eager(fn)
|
|
|
|
x = torch.randn(2, 2).as_subclass(MySubclass)
|
|
res_exp = fn(x)
|
|
res_act = fn_opt(x)
|
|
self.assertEqual(res_exp, res_act)
|
|
|
|
def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self):
|
|
# This is a slight variation of
|
|
# https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435
|
|
# which basically
|
|
# 1. uses tensor subclass to attach quantization metadata onto tensors
|
|
# 2. preserve them across torch ops
|
|
# 3. use the metadata to dequantize the tensor
|
|
# 4. convert it to a regular tensor.
|
|
#
|
|
# The test is meant to make sure Dynamo won't graph break over it.
|
|
class GGUFParameter(torch.nn.Parameter):
|
|
def __new__(cls, data, requires_grad=False, quant_type=None):
|
|
data = data if data is not None else torch.empty(0)
|
|
self = torch.Tensor._make_subclass(cls, data, requires_grad)
|
|
return self
|
|
|
|
def __init__(self, *args, quant_type=None, **kwargs):
|
|
self.quant_type = quant_type
|
|
|
|
def as_tensor(self):
|
|
return torch.Tensor._make_subclass(
|
|
torch.Tensor, self, self.requires_grad
|
|
)
|
|
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
result = super().__torch_function__(func, types, args, kwargs)
|
|
|
|
quant_type = None
|
|
for arg in args:
|
|
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
|
|
quant_type = arg[0].quant_type
|
|
break
|
|
if isinstance(arg, GGUFParameter):
|
|
quant_type = arg.quant_type
|
|
break
|
|
if isinstance(result, torch.Tensor):
|
|
return cls(result, quant_type=quant_type)
|
|
# Handle tuples and lists
|
|
elif isinstance(result, (tuple, list)):
|
|
# Preserve the original type (tuple or list)
|
|
wrapped = [
|
|
(
|
|
cls(x, quant_type=quant_type)
|
|
if isinstance(x, torch.Tensor)
|
|
else x
|
|
)
|
|
for x in result
|
|
]
|
|
return type(result)(wrapped)
|
|
else:
|
|
return result
|
|
|
|
def f(x):
|
|
tmp = x * 2
|
|
tmp = tmp + tmp.quant_type
|
|
tmp = tmp.as_tensor()
|
|
return tmp * 3
|
|
|
|
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
|
|
|
x = GGUFParameter(torch.ones(2), quant_type=42)
|
|
res = f(x)
|
|
ref = opt_f(x)
|
|
self.assertEqual(res, ref)
|
|
|
|
def test_newly_constructed_tensor_subclass_attr_mutation(self):
|
|
# Make sure the attribute mutation for newly constructed tensor subclass
|
|
# object (from constructor call) is handled both during Dynamo tracing
|
|
# and codegen-ed to be visible outside `torch.compile`.
|
|
class MySubclass(torch.Tensor):
|
|
pass
|
|
|
|
def f():
|
|
x = MySubclass(torch.ones(2))
|
|
x.bar = 42
|
|
return x, x * x.bar
|
|
|
|
opt_f = compile_full_eager(f)
|
|
|
|
res = f()
|
|
ref = opt_f()
|
|
|
|
self.assertEqual(res, ref)
|
|
self.assertEqual(res[0].bar, ref[0].bar)
|
|
|
|
def test_as_subclass_attr_mutation(self):
|
|
# Make sure the attribute mutation for newly constructed tensor subclass
|
|
# object (from as_subclass call) is handled both during Dynamo tracing
|
|
# and codegen-ed to be visible outside `torch.compile`.
|
|
class MySubclass(torch.Tensor):
|
|
pass
|
|
|
|
def f():
|
|
x = torch.ones(2).as_subclass(MySubclass)
|
|
x.bar = 42
|
|
return x, x * x.bar
|
|
|
|
opt_f = compile_full_eager(f)
|
|
|
|
res = f()
|
|
ref = opt_f()
|
|
|
|
self.assertEqual(res, ref)
|
|
self.assertEqual(res[0].bar, ref[0].bar)
|
|
|
|
def test_tensor_subclass_attr_codegen_tos(self):
|
|
# This repros a very subtle interaction between
|
|
# `TensorWithTFOverrideVariable` attribute mutation codegen and
|
|
# `PyCodegen.top_of_stack`. It was uncovered from
|
|
# `test_tensor_subclass_deepcopy`.
|
|
class MySubclass(torch.Tensor):
|
|
def __new__(cls, elem, *args, **kwargs):
|
|
r = torch.Tensor._make_subclass(cls, torch.ones(0))
|
|
r.elem = elem
|
|
return r
|
|
|
|
def f(t):
|
|
return MySubclass(t.elem.clone())
|
|
|
|
opt_f = compile_full_eager(f)
|
|
|
|
t = MySubclass(torch.ones(2))
|
|
res = f(t)
|
|
ref = opt_f(t)
|
|
|
|
self.assertEqual(res, ref)
|
|
self.assertEqual(res.elem, ref.elem)
|
|
self.assertEqual(type(res), type(ref))
|
|
|
|
def test_nontraceable_tensor_subclass(self):
|
|
# This will error if Dynamo tries to wrap it as a tensor variable,
|
|
# because that involves calling certain methods to inspect the tensor
|
|
# property, which will blow up in the overridden `__torch_function__`.
|
|
class MySubclass(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
raise RuntimeError("one shall not pass")
|
|
|
|
def f(t):
|
|
return t.foo + torch.ones(10)
|
|
|
|
opt_f = torch.compile(f, backend="eager", fullgraph=False)
|
|
|
|
t = MySubclass(torch.ones(2))
|
|
t.foo = 42
|
|
# Make sure the `nontraceable_tensor_subclasses` config prevents Dynamo
|
|
# from wrapping `t`.
|
|
with nontraceable_subclass(MySubclass):
|
|
res = f(t)
|
|
ref = opt_f(t)
|
|
|
|
self.assertEqual(res, ref)
|
|
|
|
def test_compile_with_fake_tensor_dynamic_dim(self):
|
|
x = torch.randn([3, 4])
|
|
|
|
def f(x):
|
|
return torch.sin(x)
|
|
|
|
def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count):
|
|
torch._dynamo.reset()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
opt_f = torch.compile(f, backend=cnt, fullgraph=True)
|
|
|
|
x1 = torch.rand_like(x)
|
|
f(x)
|
|
f(torch.randn([4, 3]))
|
|
shape_env = ShapeEnv()
|
|
with torch._subclasses.fake_tensor.FakeTensorMode(
|
|
shape_env=shape_env
|
|
) as fake_mode:
|
|
x_fake = fake_mode.from_tensor(
|
|
x,
|
|
symbolic_context=StatelessSymbolicContext(
|
|
dynamic_sizes=[dim_dynamic for i in range(x.dim())]
|
|
),
|
|
)
|
|
x1_fake = fake_mode.from_tensor(
|
|
x1,
|
|
symbolic_context=StatelessSymbolicContext(
|
|
dynamic_sizes=[dim_dynamic for i in range(x.dim())]
|
|
),
|
|
)
|
|
opt_f(x_fake)
|
|
opt_f(x1_fake)
|
|
|
|
self.assertEqual(cnt.frame_count, exp_frame_count)
|
|
self.assertEqual(cnt.op_count, exp_op_count)
|
|
|
|
test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1)
|
|
test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1)
|
|
test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1)
|
|
|
|
def test_compile_with_fake_tensor_automatic_dynamic(self):
|
|
def f(x):
|
|
return torch.sin(x)
|
|
|
|
def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count):
|
|
torch._dynamo.reset()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_f = torch.compile(f, backend=cnt, fullgraph=True)
|
|
|
|
shape_env = ShapeEnv()
|
|
with torch._subclasses.fake_tensor.FakeTensorMode(
|
|
shape_env=shape_env
|
|
) as fake_mode:
|
|
for inp in inps:
|
|
fake_inp = fake_mode.from_tensor(
|
|
inp,
|
|
symbolic_context=StatelessSymbolicContext(
|
|
[dim_dynamic for i in range(x.dim())]
|
|
),
|
|
)
|
|
opt_f(fake_inp)
|
|
self.assertEqual(cnt.frame_count, exp_frame_count)
|
|
self.assertEqual(cnt.op_count, exp_op_count)
|
|
|
|
x = torch.randn([3, 4])
|
|
y = torch.randn([4, 5])
|
|
z = torch.randn([5, 6])
|
|
a = torch.randn([3, 5])
|
|
b = torch.randn([4, 4])
|
|
# When inputs' DimDynamic is DYNAMIC or DUCK, the inputs
|
|
# to opt_f will be tensors with SymInt sizes. Dynamo will treat input
|
|
# as dynamic automatically and will only compile once
|
|
for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]:
|
|
test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1)
|
|
test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1)
|
|
test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1)
|
|
|
|
for dim_dynamic in [DimDynamic.STATIC]:
|
|
# Recompile once, with dims 0 and 1 becoming Dynamic
|
|
test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2)
|
|
# Recompile 2 times, first with dim 1 becoming Dynamic, second with dim 0 becoming Dynamic.
|
|
test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3)
|
|
# Recompile 2 times, first with dim 0 becoming Dynamic, second with dim 1 becoming Dynamic.
|
|
test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3)
|
|
|
|
def test_compile_with_functionalization(self):
|
|
x = torch.randn([3, 4])
|
|
x_clone = x.clone()
|
|
x_clone2 = x.clone()
|
|
backend = EagerRecordGraphAndInputs()
|
|
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x):
|
|
return x.add_(1.0) + torch.nn.functional.relu_(x)
|
|
|
|
f_out = f(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 3)
|
|
self.assertEqual(len(backend.graphs), 1)
|
|
self.assertEqual(len(backend.example_inputs), 1)
|
|
|
|
actual = normalize_gm(backend.graphs[0].print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
|
|
add_: "f32[3, 4]" = l_x_.add_(1.0)
|
|
relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None
|
|
add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
|
|
ff = torch.func.functionalize(f)
|
|
ff_out = ff(x_clone)
|
|
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(cnt.op_count, 6)
|
|
self.assertEqual(len(backend.graphs), 2)
|
|
self.assertEqual(len(backend.example_inputs), 2)
|
|
actual = normalize_gm(backend.graphs[1].print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
|
|
add_: "f32[3, 4]" = l_x_.add_(1.0)
|
|
relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None
|
|
add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
|
|
|
|
# Cannot reuse the version from AOTAutograd, since that uses python functional tensors.
|
|
def to_fun(x):
|
|
x_functional = torch._to_functional_tensor(x)
|
|
torch._mirror_autograd_meta_to(x, x_functional)
|
|
return x_functional
|
|
|
|
def aot_f_wrapper(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
torch._enable_functionalization(reapply_views=False)
|
|
try:
|
|
func_args = pytree.tree_map(to_fun, args)
|
|
func_kwargs = pytree.tree_map(to_fun, kwargs)
|
|
return func(*func_args, **func_kwargs)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
return wrapper
|
|
|
|
aot_ff = aot_f_wrapper(f)
|
|
aot_ff_out = aot_ff(x_clone2)
|
|
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
self.assertEqual(cnt.op_count, 9)
|
|
self.assertEqual(len(backend.graphs), 3)
|
|
self.assertEqual(len(backend.example_inputs), 3)
|
|
actual = normalize_gm(backend.graphs[2].print_readable(print_output=False))
|
|
self.assertExpectedInline(
|
|
actual,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
|
|
add_: "f32[3, 4]" = l_x_.add_(1.0)
|
|
relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None
|
|
add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None
|
|
return (add,)
|
|
""",
|
|
)
|
|
self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
|
|
|
|
self.assertEqual(f_out, ff_out)
|
|
self.assertEqual(f_out, aot_ff_out)
|
|
|
|
try:
|
|
torch._enable_functionalization(reapply_views=False)
|
|
xf = pytree.tree_map(to_fun, x)
|
|
x_view = xf.t()
|
|
with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"):
|
|
f(x_view)
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
def test_compile_higher_order_with_functionalization(self):
|
|
backend = EagerRecordGraphAndInputs()
|
|
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f(x):
|
|
return wrap(lambda x: x.add_(1.0), x)
|
|
|
|
def check_count_and_graph(
|
|
exp_frame_count, exp_op_count, exp_n_graph, exp_graph
|
|
):
|
|
self.assertEqual(cnt.frame_count, exp_frame_count)
|
|
self.assertEqual(cnt.op_count, exp_op_count)
|
|
self.assertEqual(len(backend.graphs), exp_n_graph)
|
|
actual = normalize_gm(
|
|
backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
|
|
)
|
|
self.assertExpectedInline(actual, exp_graph, skip=1)
|
|
|
|
t = torch.randn([3, 4])
|
|
t_clone = t.clone()
|
|
t_clone2 = t.clone()
|
|
f(t)
|
|
|
|
check_count_and_graph(
|
|
1,
|
|
2,
|
|
1,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
|
getitem: "f32[3, 4]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3, 4]"):
|
|
add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
|
|
return (add_,)
|
|
""",
|
|
)
|
|
|
|
ff = torch.func.functionalize(f)
|
|
ff_out = ff(t_clone) # noqa: F841
|
|
# frame count and op count are incremented due to re-compilation
|
|
check_count_and_graph(
|
|
2,
|
|
4,
|
|
2,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
|
getitem: "f32[3, 4]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3, 4]"):
|
|
add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
|
|
return (add_,)
|
|
""",
|
|
)
|
|
|
|
try:
|
|
x = torch._to_functional_tensor(t_clone2)
|
|
torch._mirror_autograd_meta_to(t_clone2, x)
|
|
torch._enable_functionalization(reapply_views=False)
|
|
aot_f_out = f(x) # noqa: F841
|
|
finally:
|
|
torch._disable_functionalization()
|
|
|
|
# frame count and op count are incremented due to re-compilation
|
|
check_count_and_graph(
|
|
3,
|
|
6,
|
|
3,
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, L_x_: "f32[3, 4]"):
|
|
l_x_ = L_x_
|
|
|
|
wrap_body_0 = self.wrap_body_0
|
|
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
|
getitem: "f32[3, 4]" = wrap[0]; wrap = None
|
|
return (getitem,)
|
|
|
|
class wrap_body_0(torch.nn.Module):
|
|
def forward(self, l_x_: "f32[3, 4]"):
|
|
add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
|
|
return (add_,)
|
|
""",
|
|
)
|
|
|
|
def test_has_torch_function(self):
|
|
class MyTensor:
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
if func is torch.max:
|
|
return torch.tensor(123)
|
|
return func(*args, **kwargs)
|
|
|
|
class LocalSubclass(torch.Tensor):
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
return func(*args, **kwargs)
|
|
|
|
def fn(x):
|
|
return torch.overrides.has_torch_function_unary(
|
|
x
|
|
), torch.overrides.has_torch_function_variadic(x)
|
|
|
|
for test_class in [MyTensor, LocalSubclass]:
|
|
x = test_class()
|
|
ref0 = fn(x)
|
|
ref1 = fn(4)
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
res0 = opt_fn(x)
|
|
res1 = opt_fn(4)
|
|
self.assertEqual(ref0, res0)
|
|
self.assertEqual(ref1, res1)
|
|
|
|
def test_wrapper_subclass_guards_on_inner_tensor(self):
|
|
# Holds an inner tensor, that has a distinct shape from the outer wrapper tensor.
|
|
# Also adds additional guards on the inner tensor's sizes.
|
|
# When the first input to an op has x.shape[0] > 5, we insert an extra add node.
|
|
class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor):
|
|
@staticmethod
|
|
def __new__(cls, inner):
|
|
# Double the outer-most dimension
|
|
outer_shape = (inner.shape[0] * 2,) + inner.shape[1:]
|
|
return torch.Tensor._make_wrapper_subclass(
|
|
# TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
|
|
# Calling the overload that has kwargs causes us to go down the first overload path,
|
|
# which will **always** specialize sizes.
|
|
# We should probably eventually fix this so that the first overload can just handle dynamic shapes.
|
|
cls,
|
|
outer_shape,
|
|
inner.stride(),
|
|
None,
|
|
None,
|
|
inner.dtype,
|
|
inner.layout,
|
|
inner.device,
|
|
False,
|
|
inner.requires_grad,
|
|
)
|
|
|
|
def __init__(self, inner):
|
|
self.inner_elem = inner
|
|
|
|
def __tensor_flatten__(self):
|
|
return ["inner_elem"], None
|
|
|
|
@staticmethod
|
|
def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride):
|
|
return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])
|
|
|
|
def __repr__(self):
|
|
return f"DoubleSizeMayberAddGeThreeTensor({repr(self.inner_elem)})"
|
|
|
|
@classmethod
|
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
args_inner = torch.utils._pytree.tree_map_only(
|
|
DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args
|
|
)
|
|
out_inner = func(*args_inner, **kwargs)
|
|
|
|
# Add guards on the inner tensor's sizes
|
|
if args_inner[0].shape[0] > 3:
|
|
out_inner += 2
|
|
|
|
return DoubleSizeMaybeAddGeThreeTensor(out_inner)
|
|
|
|
curr_var_to_val = None
|
|
curr_var_to_sources = None
|
|
guards = None
|
|
|
|
def backend(gm, args):
|
|
context = torch._guards.TracingContext.get()
|
|
|
|
# Grab info on sources and guards from the shapeenv
|
|
nonlocal curr_var_to_val
|
|
nonlocal curr_var_to_sources
|
|
nonlocal guards
|
|
|
|
guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
|
|
curr_var_to_val = {
|
|
str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
|
|
}
|
|
curr_var_to_sources = {
|
|
str(k): v[0].name()
|
|
for k, v in context.fake_mode.shape_env.var_to_sources.items()
|
|
}
|
|
return gm
|
|
|
|
@torch.compile(backend=backend)
|
|
def fn(x):
|
|
if x.shape[0] < 13:
|
|
return torch.mul(x, x)
|
|
else:
|
|
return torch.div(x, x)
|
|
|
|
inp = torch.ones(4, 4)
|
|
|
|
x = DoubleSizeMaybeAddGeThreeTensor(inp)
|
|
torch._dynamo.mark_dynamic(x, 0)
|
|
res = fn(x) # noqa: F841
|
|
# During fakeifying, we end up allocating a separate symint
|
|
# for the outer and inner tensor (in this test, s0 is unused).
|
|
expected_var_to_val = {
|
|
"s50": 4,
|
|
"s77": 8,
|
|
}
|
|
expected_var_to_sources = {
|
|
"s50": "L['x'].inner_elem.size()[0]",
|
|
"s77": "L['x'].size()[0]",
|
|
}
|
|
self.assertEqual(curr_var_to_val, expected_var_to_val)
|
|
self.assertEqual(curr_var_to_sources, expected_var_to_sources)
|
|
self.assertExpectedInline(
|
|
"\n".join(guards),
|
|
"""\
|
|
Eq(2*s50, s77)
|
|
2*s50 < 13
|
|
s50 > 3""",
|
|
)
|
|
|
|
def test_wrapper_subclass_with_same_sized_inner_tensor(self):
|
|
# shouldn't recompile for different sizes when dynamic=True
|
|
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
|
|
sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7))
|
|
self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True))
|
|
|
|
# should recompile for different data size when dynamic=False
|
|
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
|
|
sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
|
|
self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
|
|
|
|
# avoid recompile using manual mark_dynamic() for different data size
|
|
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
|
|
# NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size
|
|
torch._dynamo.mark_dynamic(sub1, 0)
|
|
torch._dynamo.mark_dynamic(sub1, 1)
|
|
sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
|
|
self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
|
|
|
|
def test_wrapper_subclass_with_differently_sized_inner_tensor(self):
|
|
# should recompile for different scale size when dynamic=False
|
|
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
|
|
sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
|
|
self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
|
|
|
|
# still recompiles using manual mark_dynamic() on outer for different scale size
|
|
sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
|
|
# NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size
|
|
torch._dynamo.mark_dynamic(sub1, 0)
|
|
torch._dynamo.mark_dynamic(sub1, 1)
|
|
sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
|
|
self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))
|
|
|
|
def test_recompiles_with_optional_inner_tensor(self):
|
|
def f(x):
|
|
return x + 1
|
|
|
|
# sub1 does not have the optional tensor specified while sub2 does
|
|
sub1 = OptionalScaledTensor(torch.randn(2, 4), None)
|
|
sub2 = OptionalScaledTensor(torch.randn(2, 4), torch.randn(2, 4))
|
|
|
|
# sanity check; don't recompile for same input
|
|
self.assertFalse(_recompiles_for_inputs(f, (sub1,), (sub1,), dynamic=True))
|
|
self.assertFalse(_recompiles_for_inputs(f, (sub2,), (sub2,), dynamic=True))
|
|
|
|
# these should recompile; optional tensor changes between specified and unspecified
|
|
self.assertTrue(_recompiles_for_inputs(f, (sub1,), (sub2,), dynamic=True))
|
|
self.assertTrue(_recompiles_for_inputs(f, (sub2,), (sub1,), dynamic=True))
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
self.assertEqual(f(sub1)._data, f_compiled(sub1)._data)
|
|
self.assertEqual(f(sub2)._data, f_compiled(sub2)._data)
|
|
|
|
def test_torch_dispatch_subclass_guard_recompile(self):
|
|
x = torch.ones(2, 2)
|
|
x_two = TwoTensor(x.clone(), x.clone())
|
|
|
|
def fn(w):
|
|
return torch.add(w, 1.0)
|
|
|
|
fn_opt = torch.compile(backend="eager")(fn)
|
|
|
|
ref = fn(x_two)
|
|
res = fn_opt(x_two)
|
|
self.assertEqual(ref, res)
|
|
|
|
# ensure no recompilation on same input type
|
|
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
|
fn_opt(TwoTensor(x + 1, x + 2))
|
|
|
|
# recompile!
|
|
ref = fn(x)
|
|
res = fn_opt(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_tensor_subclass_ctx_guards(self):
|
|
x = CtxSubclassTensor(torch.ones(2), 3)
|
|
x2 = CtxSubclassTensor(torch.ones(2), 3)
|
|
x3 = CtxSubclassTensor(torch.ones(2), 4)
|
|
_check_recompiles(self, lambda x: x * x, (x,), (x2,), False)
|
|
_check_recompiles(self, lambda x: x * x, (x,), (x3,), True)
|
|
|
|
def test_tensor_subclass_ctx_recursive_guards(self):
|
|
x0 = torch.ones(2, 2)
|
|
x1 = CtxSubclassTensor(x0.clone(), 2)
|
|
x2 = CtxSubclassTensor(x0.clone(), 3)
|
|
tt0 = TwoTensor(x0.clone(), x1)
|
|
tt1 = TwoTensor(x0.clone(), x2)
|
|
|
|
_check_recompiles(self, lambda x: x * x, (tt0,), (tt1,), True)
|
|
|
|
def test_tensor_subclass_ctx_custom_guards_override(self):
|
|
class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
|
|
@classmethod
|
|
def __metadata_guard__(cls, orig_data, other):
|
|
return orig_data[0] <= other[0]
|
|
|
|
x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 2)
|
|
x2 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
|
|
x3 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 1)
|
|
_check_recompiles(self, lambda x: x * x, (x,), (x2,), False)
|
|
_check_recompiles(self, lambda x: x * x, (x,), (x3,), True)
|
|
|
|
def test_tensor_subclass_ctx_custom_guards_error_arg_num(self):
|
|
import torch._dynamo.exc
|
|
|
|
class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
|
|
@classmethod
|
|
def __metadata_guard__(cls, y):
|
|
# Shouldn't reach here
|
|
return False
|
|
|
|
x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
|
|
self.assertRaisesRegex(
|
|
torch._dynamo.exc.InternalTorchDynamoError,
|
|
"Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments",
|
|
lambda: torch.compile(lambda x: x * x)(x),
|
|
)
|
|
|
|
def test_tensor_subclass_ctx_custom_guards_error_not_classmethod(self):
|
|
import torch._dynamo.exc
|
|
|
|
class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor):
|
|
def __metadata_guard__(self, x, y):
|
|
return False
|
|
|
|
x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3)
|
|
self.assertRaisesRegex(
|
|
torch._dynamo.exc.InternalTorchDynamoError,
|
|
"Tensor subclass method __metadata_guard__ must be a classmethod",
|
|
lambda: torch.compile(lambda x: x * x)(x),
|
|
)
|
|
|
|
def test_subclass_constructor_proxying(self):
|
|
import dataclasses
|
|
from collections import namedtuple
|
|
from typing import Any
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class SubclassTensorArgs:
|
|
original_shape: torch.Size
|
|
device: torch.device
|
|
inner_meta: Any
|
|
|
|
SubclassTensorArgs2 = namedtuple(
|
|
"SubclassTensorArgs2",
|
|
[
|
|
"original_shape",
|
|
"device",
|
|
"inner_meta",
|
|
],
|
|
)
|
|
|
|
class SubclassTensor(torch.Tensor):
|
|
@staticmethod
|
|
def __new__(cls, a, meta):
|
|
shape = a.shape
|
|
kwargs = {}
|
|
kwargs["strides"] = a.stride()
|
|
kwargs["storage_offset"] = a.storage_offset()
|
|
kwargs["device"] = a.device
|
|
kwargs["layout"] = a.layout
|
|
kwargs["requires_grad"] = a.requires_grad
|
|
kwargs["dtype"] = a.dtype
|
|
out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
|
|
return out
|
|
|
|
def __init__(self, a, meta):
|
|
self.a = a
|
|
self.meta = meta
|
|
|
|
def __repr__(self):
|
|
a_repr = repr(self.a)
|
|
return f"SubclassTensor({a_repr})"
|
|
|
|
def __tensor_flatten__(self):
|
|
return ["a"], self.meta
|
|
|
|
@staticmethod
|
|
def __tensor_unflatten__(inner_tensors, meta, _, __):
|
|
a = inner_tensors["a"]
|
|
return SubclassTensor(a, meta)
|
|
|
|
@classmethod
|
|
def __torch_dispatch__(cls, func, types, args, kwargs):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
args_a = pytree.tree_map(
|
|
lambda x: x.a if isinstance(x, SubclassTensor) else x, args
|
|
)
|
|
kwargs_a = pytree.tree_map(
|
|
lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs
|
|
)
|
|
out_a = func(*args_a, **kwargs_a)
|
|
out = pytree.tree_map(
|
|
lambda x: (
|
|
SubclassTensor(x, SubclassTensorArgs2(x.shape, x.device, None))
|
|
if isinstance(x, torch.Tensor)
|
|
else x
|
|
),
|
|
out_a,
|
|
)
|
|
return return_and_correct_aliasing(func, args, kwargs, out)
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def f1(x):
|
|
meta = SubclassTensorArgs(
|
|
x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None)
|
|
)
|
|
out = SubclassTensor(x, meta)
|
|
return out * out
|
|
|
|
x = torch.randn(3, 3)
|
|
f1(x)
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def f1(x):
|
|
meta = SubclassTensorArgs2(
|
|
x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None)
|
|
)
|
|
out = SubclassTensor(x, meta)
|
|
return out * out
|
|
|
|
x = torch.randn(3, 3)
|
|
f1(x)
|
|
|
|
    def test_torch_function_subclass_survives_into_aot_autograd(self):
        # If you have a tensor subclass that relies on dispatching into the same op
        # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(),
        # the torch-function-ness will survive into AOTAutograd. Today, NestedTensor
        # actually relies on this behavior! Because that torch function logic runs
        # during AOTAutograd, this test checks that no logic below relies on torch
        # function behavior that gets unexpectedly disabled after we redispatch
        # from the subclass's torch function.
class SubTensor(torch.Tensor):
|
|
@staticmethod
|
|
def __new__(cls, t):
|
|
return torch.Tensor._make_wrapper_subclass(
|
|
cls,
|
|
t.shape,
|
|
t.stride(),
|
|
t.storage_offset(),
|
|
torch.contiguous_format,
|
|
t.dtype,
|
|
torch.strided,
|
|
t.device,
|
|
False,
|
|
t.requires_grad,
|
|
"sizes",
|
|
False,
|
|
False,
|
|
None,
|
|
)
|
|
|
|
def __init__(self, t):
|
|
super().__init__()
|
|
self._t = t
|
|
|
|
def __tensor_flatten__(self):
|
|
return ["_t"], {}
|
|
|
|
@staticmethod
|
|
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
|
t = inner_tensors["_t"]
|
|
return SubTensor(t)
|
|
|
|
def __repr__(self):
|
|
return f"SubTensor({self._t})"
|
|
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
with torch._C.DisableTorchFunctionSubclass():
|
|
return func(*args, **kwargs)
|
|
|
|
@classmethod
|
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
|
kwargs = {} if kwargs is None else kwargs
|
|
new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args)
|
|
output = func(*new_args, **kwargs)
|
|
output = pytree.tree_map_only(
|
|
torch.Tensor, lambda t: SubTensor(t), output
|
|
)
|
|
return output
|
|
|
|
@torch.compile(dynamic=True)
|
|
def f(x):
|
|
return x.unflatten(-1, [2, 5])
|
|
|
|
s = SubTensor(torch.randn(3, 10))
|
|
f(s)
|
|
|
|
# Guard validation upsets the guard
|
|
# https://github.com/pytorch/pytorch/issues/129936
|
|
@unittest.expectedFailure
|
|
def test_recompile_with_symbool_inputs(self):
|
|
def f(pred: bool):
|
|
if pred:
|
|
return torch.ones([3, 4])
|
|
else:
|
|
return torch.ones([4, 3])
|
|
|
|
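        # The helper below compiles `f` on SymBool predicates built from a fake
        # tensor's size and, after each call, checks the captured graph, the
        # compile frame count, and the guards accumulated on the outer ShapeEnv.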
def test_recompilation(
|
|
f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
|
|
):
|
|
torch._dynamo.reset()
|
|
shape_env = ShapeEnv()
|
|
backend = torch._dynamo.testing.EagerAndRecordGraphs()
|
|
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
|
|
f_cond = torch.compile(f, backend=cnt, fullgraph=True)
|
|
with torch._subclasses.fake_tensor.FakeTensorMode(
|
|
shape_env=shape_env
|
|
) as fake_mode:
|
|
fake_inp = fake_mode.from_tensor(
|
|
x,
|
|
symbolic_context=StatelessSymbolicContext(
|
|
dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())]
|
|
),
|
|
)
|
|
for i, size in enumerate(sizes):
|
|
pred = fake_inp.size(0) == size
|
|
f_cond(pred)
|
|
actual = normalize_gm(
|
|
backend.graphs[exp_frame_count[i] - 1].print_readable(
|
|
print_output=False
|
|
)
|
|
)
|
|
actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
|
|
self.assertExpectedInline(actual, exp_graphs[i])
|
|
self.assertEqual(cnt.frame_count, exp_frame_count[i])
|
|
self.assertEqual(actual_guard_str, exp_shape_env_guards[i])
|
|
|
|
true_graph = """\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self):
|
|
ones: "f32[3, 4]" = torch.ones([3, 4])
|
|
return (ones,)
|
|
"""
|
|
false_graph = """\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self):
|
|
ones: "f32[4, 3]" = torch.ones([4, 3])
|
|
return (ones,)
|
|
"""
|
|
test_recompilation(
|
|
f,
|
|
torch.randn([3, 4]),
|
|
[3, 3, 4, 5],
|
|
exp_graphs=[true_graph, true_graph, false_graph, false_graph],
|
|
exp_frame_count=[1, 1, 2, 2],
|
|
exp_shape_env_guards=[
|
|
[],
|
|
# s0 is specialized and guarded in outer shape_env when dynamo checks the guards
|
|
["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
|
|
[
|
|
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
|
|
"Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
|
|
],
|
|
[
|
|
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
|
|
"Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
|
|
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
|
|
],
|
|
],
|
|
)
|
|
|
|
test_recompilation(
|
|
f,
|
|
torch.randn([3, 4]),
|
|
[4, 5, 3, 3],
|
|
exp_graphs=[false_graph, false_graph, true_graph, true_graph],
|
|
exp_frame_count=[1, 1, 2, 2],
|
|
exp_shape_env_guards=[
|
|
[],
|
|
# s0 is specialized and guarded in outer shape_env when dynamo checks the guards
|
|
["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
|
|
[
|
|
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
|
|
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
|
|
],
|
|
[
|
|
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
|
|
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
|
|
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
|
|
],
|
|
],
|
|
)
|
|
|
|
    def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self):
        def f(x_subclass):
            tmp_subclass = torch.add(x_subclass, 1)
            return torch.mul(tmp_subclass._scale, tmp_subclass._constant)

        x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2)
        out_ref = f(x)
        out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
        self.assertEqual(out_ref, out_test)

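    # The next test compiles attribute access on type(...).__bases__ / __base__
    # for a class built with a metaclass that combines ABCMeta and
    # ProxyableClassMeta, checking that Dynamo resolves metaclass-constructed
    # types under fullgraph=True.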
def test_support_bases(self):
|
|
import abc
|
|
|
|
import torch.fx._symbolic_trace
|
|
|
|
class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
|
|
def __new__(cls, name, bases, dct):
|
|
x = super().__new__(cls, name, bases, dct)
|
|
x.attr = 100
|
|
return x
|
|
|
|
class Multistreamable(abc.ABC): # noqa: B024
|
|
pass
|
|
|
|
class Foo(Multistreamable, metaclass=Meta):
|
|
pass
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def f(x):
|
|
typ = type(Foo())
|
|
typ.__bases__
|
|
return typ.__bases__
|
|
|
|
self.assertEqual(f(torch.randn(1)), (Multistreamable,))
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def g(x):
|
|
typ = type(Foo())
|
|
typ.__base__
|
|
return typ.__base__
|
|
|
|
self.assertEqual(g(torch.randn(1)), Multistreamable)
|
|
|
|
@parametrize("dynamic", [False, True])
|
|
def test_subclass_views(self, dynamic):
|
|
        def _get_views(t):  # yields (view: Tensor, expects_raises: bool)
|
|
# Note that any closed-over SymInts will be symbolicized during fake-ification.
|
|
yield t.narrow(dim=-1, start=3, length=8), False
|
|
yield t.split(5, -1)[2], False
|
|
yield t.split_with_sizes([9, 6], -1)[1], False
|
|
yield t.unsqueeze(-1).expand(4, 15, 10), False
|
|
yield t.select(-1, 6), False
|
|
# https://github.com/pytorch/pytorch/issues/128649
|
|
yield t[2:3, 5:9], dynamic
|
|
yield t.view(-1, 15), False
|
|
|
|
def f(x):
|
|
return x * 2
|
|
|
|
compiled_f = torch.compile(
|
|
f, backend="aot_eager", fullgraph=True, dynamic=dynamic
|
|
)
|
|
|
|
# Take a view of a subclass to pass as input.
|
|
t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15))
|
|
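        # Reset Dynamo before compiling each view so that guards installed for
        # one view's shape/storage offset do not carry over to the next case.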
for view, expects_raises in _get_views(t):
|
|
torch._dynamo.reset()
|
|
out_ref = f(view)
|
|
if expects_raises:
|
|
with self.assertRaises(AssertionError):
|
|
out_test = compiled_f(view)
|
|
else:
|
|
out_test = compiled_f(view)
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
|
@parametrize("dynamic", [True, False])
|
|
def test_mark_static_with_subclass_desugaring(self, dynamic):
|
|
from collections.abc import Callable
|
|
from typing import Any, Optional
|
|
|
|
from torch._dynamo.decorators import mark_static_address
|
|
from torch._inductor.compile_fx import compile_fx
|
|
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
|
|
from torch._inductor.utils import BoxedBool
|
|
|
|
x_inner = torch.ones(4)
|
|
x = TwoTensor(x_inner, x_inner)
|
|
mark_static_address(x, guard=False)
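        # After subclass desugaring, the TwoTensor argument contributes two inner
        # tensor inputs (plus symint inputs when dynamic=True), so the static
        # input indices asserted below shift accordingly.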
|
|
|
|
def inner_compile(
|
|
gm: torch.fx.GraphModule,
|
|
example_inputs: list[torch.Tensor],
|
|
cudagraphs: Optional[BoxedBool] = None,
|
|
static_input_idxs: Optional[list[int]] = None,
|
|
is_backward: bool = False,
|
|
graph_id: Optional[int] = None,
|
|
cpp_wrapper: bool = False,
|
|
aot_mode: bool = False,
|
|
is_inference: bool = False,
|
|
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
|
|
layout_opt: Optional[bool] = None,
|
|
extern_node_serializer: Optional[Callable[[list[Any]], Any]] = None,
|
|
):
|
|
if dynamic:
|
|
self.assertEqual(static_input_idxs, [2, 3, 4])
|
|
else:
|
|
self.assertEqual(static_input_idxs, [1, 2])
|
|
return gm
|
|
|
|
compiler = functools.partial(compile_fx, inner_compile=inner_compile)
|
|
|
|
@torch.compile(backend=compiler, dynamic=dynamic)
|
|
def fn(t0, t1, t2):
|
|
return t0 + t1 + t2 + 2
|
|
|
|
fn(torch.ones(4), x, torch.ones(4))
|
|
|
|
    # copied from common_utils.py::NestedTensorTestCase
    def assertEqualIgnoringNestedInts(self, a, b):
        # unbinding NJTs allows us to compare them as essentially equal without
        # caring about exact nested int comparison
        def _unbind_njts(x):
            if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
                return x.unbind()
            else:
                return x

        self.assertEqual(
            pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b)
        )

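    # _compile_check compiles `fn` with an aot_autograd backend that records the
    # forward and backward graphs, compares compiled outputs against eager for
    # every input tuple, optionally runs backward on each output that requires
    # grad, and returns the recorded forward/backward graphs for inline checks.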
def _compile_check(
|
|
self,
|
|
fn,
|
|
inps,
|
|
*,
|
|
dynamic=True,
|
|
fullgraph=True,
|
|
call_backward=False,
|
|
):
|
|
def call_backward_fn(t):
|
|
if t.is_nested:
|
|
from torch.nested._internal.nested_tensor import buffer_from_jagged
|
|
|
|
t = buffer_from_jagged(t)
|
|
return t.sum().backward(retain_graph=True)
|
|
|
|
torch.manual_seed(0)
|
|
fw_compiler = EagerRecordGraphAndInputs()
|
|
bw_compiler = EagerRecordGraphAndInputs()
|
|
compiler_fn = aot_autograd(
|
|
fw_compiler=make_boxed_compiler(fw_compiler),
|
|
bw_compiler=make_boxed_compiler(bw_compiler),
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
keep_inference_input_mutations=True,
|
|
)
|
|
|
|
c = torch.compile(backend=compiler_fn, dynamic=dynamic, fullgraph=fullgraph)(fn)
|
|
for inp in inps:
|
|
expected = fn(*inp)
|
|
# reset the seed for randn to generate the same tensor
|
|
torch.manual_seed(0)
|
|
got = c(*inp)
|
|
self.assertEqualIgnoringNestedInts(expected, got)
|
|
|
|
if call_backward:
|
|
re = pytree.tree_map_only(
|
|
lambda x: isinstance(x, torch.Tensor) and x.requires_grad,
|
|
call_backward_fn,
|
|
expected,
|
|
)
|
|
rg = pytree.tree_map_only(
|
|
lambda x: isinstance(x, torch.Tensor) and x.requires_grad,
|
|
call_backward_fn,
|
|
got,
|
|
)
|
|
self.assertEqualIgnoringNestedInts(re, rg)
|
|
|
|
if call_backward:
|
|
return fw_compiler.graphs, bw_compiler.graphs
|
|
return fw_compiler.graphs, None
|
|
|
|
def test_tensor_subclass_TwoTensor_simple(self):
|
|
def f(tt):
|
|
return tt * tt.size()[0]
|
|
|
|
a = torch.ones(3, 4, requires_grad=True)
|
|
b = a.detach().clone().requires_grad_(True)
|
|
tt = TwoTensor(a, b)
|
|
|
|
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
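        # The expanded_def annotations below show how AOTAutograd desugars the
        # TwoTensor input: the subclass is flattened into its 'a'/'b' inner
        # tensors plus symbolic size/stride inputs, and the TwoTensor output is
        # rebuilt from the matching SubclassGetAttr/Size/Stride outputs.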
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
|
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
|
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
|
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
):
|
|
mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
|
|
mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
|
|
return (
|
|
mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
|
mul_3, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
|
primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_1, # SavedForBackwardsAOTOutput(idx=0)
|
|
primals_5, # SavedForBackwardsAOTOutput(idx=1)
|
|
primals_7, # SavedForBackwardsAOTOutput(idx=2)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
tangents_1: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
|
tangents_2: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
|
):
|
|
mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
|
|
mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
|
|
return (
|
|
None, # None
|
|
None, # None
|
|
mul_8, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
|
mul_9, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
|
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_tensor_subclass_TwoTensor_clone_view(self):
|
|
def f(tt):
|
|
y = tt.clone()
|
|
return y.view(y.shape[1], y.shape[0])
|
|
|
|
a = torch.ones(3, 4, requires_grad=True)
|
|
b = a.clone()
|
|
tt = TwoTensor(a, b)
|
|
|
|
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
|
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
|
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
|
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
):
|
|
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
|
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
|
|
|
view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
|
|
view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
|
|
return (
|
|
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
|
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
|
primals_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
|
primals_5, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
|
primals_7, # SavedForBackwardsAOTOutput(idx=1)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
tangents_1: "f32[s16, s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
|
tangents_2: "f32[s16, s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
|
):
|
|
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
|
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
|
return (
|
|
None, # None
|
|
None, # None
|
|
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
|
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
|
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_tensor_subclass_TwoTensor_mul(self):
|
|
def f(tt, a, b):
|
|
s0, s1 = a.size()
|
|
s2, s3 = b.size()
|
|
# return tt * a.size()[1]
|
|
return tt * s0 * s1 * s2 * s3
|
|
|
|
a = torch.ones(3, 4, requires_grad=True)
|
|
b = a.clone()
|
|
tt = TwoTensor(a, b)
|
|
|
|
fw, bw = self._compile_check(f, [(tt, a, b)], dynamic=True, call_backward=True)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s97)", # PlainAOTInput(idx=0)
|
|
primals_2: "Sym(s98)", # PlainAOTInput(idx=1)
|
|
primals_3: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
|
primals_4: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
|
primals_5: "Sym(s97)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_6: "Sym(s98)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
|
primals_7: "Sym(s98)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
):
|
|
mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
|
|
mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
|
|
mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
|
|
mul_11: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_3, primals_2); mul_3 = None
|
|
mul_16: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_8, primals_1); mul_8 = None
|
|
mul_19: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
|
|
mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
|
|
mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
|
|
return (
|
|
mul_24, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
|
mul_27, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
|
primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_1, # SavedForBackwardsAOTOutput(idx=0)
|
|
primals_2, # SavedForBackwardsAOTOutput(idx=1)
|
|
primals_5, # SavedForBackwardsAOTOutput(idx=2)
|
|
primals_7, # SavedForBackwardsAOTOutput(idx=3)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s97)", # PlainAOTInput(idx=0)
|
|
primals_2: "Sym(s98)", # PlainAOTInput(idx=1)
|
|
primals_5: "Sym(s97)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_7: "Sym(s98)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
tangents_1: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
|
tangents_2: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
|
):
|
|
mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
|
|
mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
|
|
mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
|
|
mul_35: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_33, primals_1); mul_33 = None
|
|
mul_36: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_34, primals_2); mul_34 = None
|
|
mul_37: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
|
|
mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
|
|
mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
|
|
return (
|
|
None, # None
|
|
None, # None
|
|
mul_38, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
|
mul_39, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
|
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_tensor_subclass_TwoTensor_view(self):
|
|
def f(tt):
|
|
y = tt.clone()
|
|
return y.view(y.shape[0], y.shape[1])
|
|
|
|
a = torch.ones(3, 4, requires_grad=True)
|
|
b = a.clone()
|
|
tt = TwoTensor(a, b)
|
|
|
|
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
|
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
|
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
|
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
):
|
|
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
|
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
|
|
|
view: "f32[s47, s16]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
|
|
view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
|
|
return (
|
|
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
|
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
|
primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
|
primals_7, # SavedForBackwardsAOTOutput(idx=1)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
tangents_1: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
|
tangents_2: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
|
):
|
|
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
|
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
|
return (
|
|
None, # None
|
|
None, # None
|
|
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
|
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
|
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_tensor_subclass_TwoTensor_view_mul(self):
|
|
def f(tt):
|
|
y = tt.clone()
|
|
return y.view(y.shape[0] * y.shape[1])
|
|
|
|
a = torch.ones(3, 4, requires_grad=True)
|
|
b = a.clone()
|
|
tt = TwoTensor(a, b)
|
|
|
|
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
|
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
|
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
|
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
):
|
|
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
|
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
|
|
|
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
|
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
|
|
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
|
return (
|
|
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
|
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
|
mul_6, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
|
primals_7, # SavedForBackwardsAOTOutput(idx=1)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
tangents_1: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
|
tangents_2: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
|
):
|
|
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
|
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
|
return (
|
|
None, # None
|
|
None, # None
|
|
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
|
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
|
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_tensor_subclass_TwoTensor_return_tensor_and_subclass(self):
|
|
def f(tt):
|
|
y = tt.clone()
|
|
return y.a, y.view(y.shape[0] * y.shape[1])
|
|
|
|
a = torch.ones(3, 4, requires_grad=True)
|
|
b = a.clone()
|
|
tt = TwoTensor(a, b)
|
|
|
|
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
|
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
|
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
|
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
):
|
|
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
|
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
|
|
|
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
|
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6])
|
|
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
|
return (
|
|
clone, # PlainAOTOutput(idx=0)
|
|
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
|
|
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
|
|
mul_6, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0)
|
|
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
|
primals_7, # SavedForBackwardsAOTOutput(idx=1)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
|
tangents_1: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
|
|
tangents_2: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
|
|
):
|
|
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
|
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
|
return (
|
|
None, # None
|
|
None, # None
|
|
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
|
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
|
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
@unittest.expectedFailure
|
|
def test_tensor_subclass_TwoTensor_return_multiple(self):
|
|
def f(tt):
|
|
y = tt.clone()
|
|
z = tt.clone()
|
|
return y.a, y.view(y.shape[0] * y.shape[1]), y.b, z.view(-1)
|
|
|
|
a = torch.ones(3, 4, requires_grad=True)
|
|
b = a.clone()
|
|
tt = TwoTensor(a, b)
|
|
|
|
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, primals_1: "f32[3, 4]", primals_2: "f32[3, 4]", primals_3: "Sym(3)", primals_4: "Sym(4)", primals_5: "Sym(3)", primals_6: "Sym(4)"):
|
|
clone: "f32[3, 4]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
|
clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
|
|
|
mul: "Sym(12)" = primals_5 * primals_6
|
|
view: "f32[12]" = torch.ops.aten.view.default(clone, [mul])
|
|
view_1: "f32[12]" = torch.ops.aten.view.default(clone_1, [mul]); clone_1 = None
|
|
return [clone, view, view_1, mul, primals_5, primals_6]
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(self, primals_5: "Sym(3)", primals_6: "Sym(4)", tangents_1: "f32[12]", tangents_2: "f32[12]"):
|
|
view_2: "f32[3, 4]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_6]); tangents_1 = None
|
|
view_3: "f32[3, 4]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_6]); tangents_2 = primals_5 = primals_6 = None
|
|
return [view_2, view_3, None, None]
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_tensor_subclass_TwoTensor_automatic_dynamic_shapes(self):
|
|
def f(tt):
|
|
y = tt.clone()
|
|
return y.a, y.view(-1), y.b
|
|
|
|
a = torch.ones(3, 4, requires_grad=True)
|
|
b = a.clone()
|
|
tt1 = TwoTensor(a, b)
|
|
|
|
a = torch.ones(3, 5, requires_grad=True)
|
|
b = a.clone()
|
|
tt2 = TwoTensor(a, b)
|
|
|
|
fw, bw = self._compile_check(
|
|
f, [(tt1,), (tt2,)], dynamic=None, call_backward=True
|
|
)
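        # The first input compiles to a fully static graph; the second input has
        # a different inner width, so automatic dynamic shapes kick in and a
        # second, symbolic (s16) graph is recorded for both forward and backward.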
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a')
|
|
primals_2: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b')
|
|
):
|
|
clone: "f32[3, 4]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
|
clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
|
|
|
view: "f32[12]" = torch.ops.aten.view.default(clone, [-1])
|
|
view_1: "f32[12]" = torch.ops.aten.view.default(clone_1, [-1])
|
|
return (
|
|
clone, # PlainAOTOutput(idx=0)
|
|
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
|
|
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
|
|
clone_1, # PlainAOTOutput(idx=2)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[1].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s16)", # PlainAOTInput(idx=0)
|
|
primals_2: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='a')
|
|
primals_3: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='b')
|
|
primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1)
|
|
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
|
|
):
|
|
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
|
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
|
|
|
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
|
|
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
|
|
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
|
|
return (
|
|
clone, # PlainAOTOutput(idx=0)
|
|
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
|
|
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
|
|
sym_size_int_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0)
|
|
clone_1, # PlainAOTOutput(idx=2)
|
|
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
tangents_1: "f32[12]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
|
|
tangents_2: "f32[12]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
|
|
):
|
|
view_2: "f32[3, 4]" = torch.ops.aten.view.default(tangents_1, [3, 4]); tangents_1 = None
|
|
view_3: "f32[3, 4]" = torch.ops.aten.view.default(tangents_2, [3, 4]); tangents_2 = None
|
|
return (
|
|
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a')
|
|
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b')
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[1].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
|
|
tangents_1: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
|
|
tangents_2: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
|
|
):
|
|
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
|
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
|
return (
|
|
None, # None
|
|
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='a')
|
|
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1)
|
|
primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_tensor_subclass_TwoTensor_mark_dynamic_shapes(self):
|
|
def f(tt):
|
|
y = tt.clone()
|
|
return y.a, y.view(-1), y.b
|
|
|
|
a = torch.ones(3, 4, requires_grad=True)
|
|
b = a.clone()
|
|
tt = TwoTensor(a, b)
|
|
torch._dynamo.mark_dynamic(tt, 1)
|
|
|
|
fw, bw = self._compile_check(
|
|
f,
|
|
[
|
|
(tt,),
|
|
],
|
|
dynamic=None,
|
|
call_backward=True,
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s16)", # PlainAOTInput(idx=0)
|
|
primals_2: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='a')
|
|
primals_3: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='b')
|
|
primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1)
|
|
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
|
|
):
|
|
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
|
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
|
|
|
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
|
|
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
|
|
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
|
|
return (
|
|
clone, # PlainAOTOutput(idx=0)
|
|
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
|
|
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
|
|
sym_size_int_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0)
|
|
clone_1, # PlainAOTOutput(idx=2)
|
|
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
|
|
tangents_1: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
|
|
tangents_2: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
|
|
):
|
|
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
|
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
|
return (
|
|
None, # None
|
|
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='a')
|
|
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='b')
|
|
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1)
|
|
primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_tensor_subclass_TwoTensor_different_shape(self):
|
|
def f(tt):
|
|
y = tt.clone()
|
|
return y.view(3, 2, 4)
|
|
|
|
a = torch.ones((2 * 4 * 3), requires_grad=True)
|
|
b = a.clone()
|
|
tt = TwoTensor(a, b)
|
|
|
|
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a')
|
|
primals_2: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b')
|
|
):
|
|
clone: "f32[24]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
|
clone_1: "f32[24]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
|
|
|
view: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone, [3, 2, 4]); clone = None
|
|
view_1: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone_1, [3, 2, 4]); clone_1 = None
|
|
return (
|
|
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
|
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
tangents_1: "f32[3, 2, 4]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
|
tangents_2: "f32[3, 2, 4]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
|
):
|
|
view_2: "f32[24]" = torch.ops.aten.view.default(tangents_1, [24]); tangents_1 = None
|
|
view_3: "f32[24]" = torch.ops.aten.view.default(tangents_2, [24]); tangents_2 = None
|
|
return (
|
|
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a')
|
|
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b')
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_tensor_subclass_TwoTensor_return_shape(self):
|
|
@torch.compile(backend="aot_eager", dynamic=True)
|
|
def fn(x):
|
|
return x.clone().view(x.shape[0] * x.shape[1])
|
|
|
|
a = torch.ones(2, 3)
|
|
b = a.clone()
|
|
tt = TwoTensor(a, b)
|
|
out = fn(tt)
|
|
self.assertEqual(tt.view(2 * 3), out)
|
|
self.assertEqual(out.shape, (6,))
|
|
|
|
    def test_tensor_subclass_TwoTensor_nested(self):
        @torch.compile(backend="aot_eager", dynamic=True)
        def f(x, i, y):
            out1 = x.sin() + i.sin() + y.sin()
            val1 = x.shape[0] * i.shape[1] * y.shape[0]
            return out1 * val1

        i = torch.randn(2, 2, requires_grad=True)
        x = TwoTensor(i, i.clone())
        y = TwoTensor(x.clone(), x.clone())

        out = f(x, i, y)

        x_test = x.detach().clone().requires_grad_(True)
        i_test = i.detach().clone().requires_grad_(True)
        y_test = y.detach().clone().requires_grad_(True)

        out_test = f(x_test, i_test, y_test)
        self.assertTrue(torch.allclose(out, out_test))

        out.sum().backward()
        out_test.sum().backward()
        self.assertTrue(torch.allclose(x.grad, x_test.grad))
        self.assertTrue(torch.allclose(i.grad, i_test.grad))
        self.assertTrue(torch.allclose(y.grad, y_test.grad))

def test_subclass_TwoTensor_TwoTensor_TwoTensor(self):
|
|
@torch.compile(backend="aot_eager", dynamic=True)
|
|
def f(x):
|
|
return x.sin()
|
|
|
|
data = torch.randn(2, 3)
|
|
s = TwoTensor(data, data.clone())
|
|
y = TwoTensor(s, s.clone())
|
|
z = TwoTensor(s, y)
|
|
out = f(z)
|
|
self.assertEqual(out, z.sin())
|
|
|
|
def test_subclass_TwoTensor_nested_diff_sizes(self):
|
|
class TT(TwoTensor):
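            # TT relaxes TwoTensor by letting the two inner tensors have different
            # shapes: the wrapper takes its size/stride from `a`, or from the
            # explicit outer_size/outer_stride passed during unflattening.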
|
|
@staticmethod
|
|
def __new__(cls, a, b, outer_size=None, outer_stride=None):
|
|
if outer_size is None:
|
|
outer_size = a.size()
|
|
if outer_stride is None:
|
|
outer_stride = a.stride()
|
|
|
|
assert (
|
|
a.device == b.device
|
|
and a.layout == b.layout
|
|
and a.requires_grad == b.requires_grad
|
|
and a.dtype == b.dtype
|
|
)
|
|
shape = outer_size
|
|
kwargs = {}
|
|
kwargs["strides"] = outer_stride
|
|
kwargs["storage_offset"] = a.storage_offset()
|
|
kwargs["device"] = a.device
|
|
kwargs["layout"] = a.layout
|
|
kwargs["requires_grad"] = a.requires_grad
|
|
kwargs["dtype"] = a.dtype
|
|
out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
|
|
return out
|
|
|
|
@staticmethod
|
|
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
|
|
assert meta is None
|
|
a, b = inner_tensors["a"], inner_tensors["b"]
|
|
if type(a) is torch.Tensor:
|
|
assert outer_size is not None
|
|
assert outer_stride is not None
|
|
return TT(a, b, outer_size, outer_stride)
|
|
|
|
@torch.compile(dynamic=True)
|
|
def f(x, y):
|
|
tmp1 = x.sin()
|
|
tmp2 = y.sin()
|
|
return tmp1.sum(), tmp2.sum()
|
|
|
|
x = TT(
|
|
TT(
|
|
torch.randn(3, 4),
|
|
torch.randn(5, 6, 7),
|
|
),
|
|
TT(
|
|
torch.randn(4),
|
|
torch.randn(2, 3),
|
|
),
|
|
)
|
|
|
|
y = TT(
|
|
torch.randn(2, 3, 4, 5),
|
|
TT(
|
|
torch.randn(3, 4),
|
|
torch.randn(5),
|
|
),
|
|
)
|
|
|
|
out = f(x, y)
|
|
self.assertEqual(out, (x.sin().sum(), y.sin().sum()))
|
|
|
|
def test_njt_subclass_simple(self):
|
|
def f(nt):
|
|
y = nt.clone()
|
|
return y * y.size(0)
|
|
|
|
nt, _ = get_jagged_tensor(((2, 3, 4), 5), None, True)
|
|
|
|
fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True)
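        # For NJTs, AOTAutograd flattens the subclass into its _values/_offsets
        # buffers plus the cached _min_seqlen/_max_seqlen tensors, along with
        # symbolic sizes/strides, as reflected in the graphs below.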
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
|
|
primals_2: "Sym(s71)", # PlainAOTInput(idx=1)
|
|
primals_3: "Sym(s55)", # PlainAOTInput(idx=2)
|
|
primals_4: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values')
|
|
primals_5: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets')
|
|
primals_6: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor')
|
|
primals_7: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor')
|
|
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
|
primals_9: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2)
|
|
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
|
):
|
|
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
|
|
|
mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
|
|
return (
|
|
mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
|
|
primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
|
|
primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
|
|
primals_7, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
|
|
primals_8, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
primals_10, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
|
|
primals_10, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
|
primals_1, # SavedForBackwardsAOTOutput(idx=0)
|
|
primals_8, # SavedForBackwardsAOTOutput(idx=1)
|
|
primals_10, # SavedForBackwardsAOTOutput(idx=2)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
|
|
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
|
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
|
tangents_1: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values')
|
|
tangents_2: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_offsets')
|
|
tangents_3: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_min_seqlen_tensor')
|
|
tangents_4: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_max_seqlen_tensor')
|
|
):
|
|
mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
|
|
return (
|
|
None, # None
|
|
None, # None
|
|
None, # None
|
|
mul_1, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_values')
|
|
tangents_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_offsets')
|
|
tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor')
|
|
tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor')
|
|
primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0)
|
|
primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
|
|
primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_njt_subclass_from_cat(self):
|
|
# create from an existing NJT
|
|
def f(nt):
|
|
y = nt.clone()
|
|
z = torch.cat([y, y], dim=-1)
|
|
return z
|
|
|
|
nt, _ = get_jagged_tensor(((2, 3, 4), 5), None, True)
|
|
|
|
fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
|
|
primals_2: "Sym(s71)", # PlainAOTInput(idx=1)
|
|
primals_3: "Sym(s55)", # PlainAOTInput(idx=2)
|
|
primals_4: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values')
|
|
primals_5: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets')
|
|
primals_6: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor')
|
|
primals_7: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor')
|
|
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
|
primals_9: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2)
|
|
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
|
):
|
|
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
|
|
|
cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
|
|
add_2: "Sym(2*s55)" = primals_10 + primals_10
|
|
return (
|
|
cat, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
|
|
primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
|
|
primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
|
|
primals_7, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
|
|
primals_8, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
|
add_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
|
|
add_2, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
|
primals_8, # SavedForBackwardsAOTOutput(idx=0)
|
|
primals_10, # SavedForBackwardsAOTOutput(idx=1)
|
|
add_2, # SavedForBackwardsAOTOutput(idx=2)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class GraphModule(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
|
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
|
add_2: "Sym(2*s55)",
|
|
tangents_1: "f64[s64, 2*s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values')
|
|
tangents_2: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_offsets')
|
|
tangents_3: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_min_seqlen_tensor')
|
|
tangents_4: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_max_seqlen_tensor')
|
|
):
|
|
slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
|
|
slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
|
|
|
|
add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
|
|
return (
|
|
None, # None
|
|
None, # None
|
|
None, # None
|
|
add_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_values')
|
|
tangents_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_offsets')
|
|
tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor')
|
|
tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor')
|
|
primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0)
|
|
primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
|
|
primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1)
|
|
)
|
|
""", # noqa: B950
|
|
)
|
|
|
|
def test_njt_subclass_from_buffer(self):
|
|
        # create a fresh NJT from newly constructed buffers inside the compiled region
|
|
def f(nt):
|
|
nested_size = ((2, 3, 4), 5)
|
|
offsets = None
|
|
nt2, _ = get_jagged_tensor(nested_size, offsets, requires_grad=False)
|
|
nt3 = torch.cat([nt2, nt], dim=-1)
|
|
return nt3.sin() * nt3.size(0)
|
|
|
|
nested_size = ((2, 3, 4), 5)
|
|
offsets = None
|
|
nt, _ = get_jagged_tensor(nested_size, offsets, requires_grad=False)
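        # Building nt2 inside the compiled region means its constituent randn,
        # cat, and cumsum calls are traced directly into the graph below rather
        # than entering as subclass inputs.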
|
|
|
|
fw, _ = self._compile_check(
|
|
f,
|
|
[(nt,)],
|
|
dynamic=True,
|
|
call_backward=False, # we cannot set requires_grad=True inside a compile region
|
|
)
|
|
|
|
self.assertExpectedInline(
|
|
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
|
"""\
|
|
class <lambda>(torch.nn.Module):
|
|
def forward(
|
|
self,
|
|
arg0_1: "Sym(s51)", # PlainAOTInput(idx=0)
|
|
arg1_1: "Sym(s71)", # PlainAOTInput(idx=1)
|
|
arg2_1: "Sym(s55)", # PlainAOTInput(idx=2)
|
|
arg3_1: "f64[9, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values')
|
|
arg4_1: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets')
|
|
arg5_1: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor')
|
|
arg6_1: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor')
|
|
arg7_1: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
|
arg8_1: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2)
|
|
arg9_1: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
|
):
|
|
randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
|
randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
|
randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
|
|
|
cat: "f64[9, 5]" = torch.ops.aten.cat.default([randn, randn_1, randn_2]); randn = randn_1 = randn_2 = None
|
|
zeros: "i64[1]" = torch.ops.aten.zeros.default([1], dtype = torch.int64, device = device(type='cpu'), pin_memory = False)
|
|
_tensor_constant0: "i64[3]" = self._tensor_constant0
|
|
lift_fresh_copy: "i64[3]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
|
|
cumsum: "i64[3]" = torch.ops.aten.cumsum.default(lift_fresh_copy, 0); lift_fresh_copy = None
|
|
cat_1: "i64[4]" = torch.ops.aten.cat.default([zeros, cumsum]); zeros = cumsum = None
|
|
zeros_1: "f32[2, 0]" = torch.ops.aten.zeros.default([2, 0], device = device(type='cpu'), pin_memory = False)
|
|
zeros_2: "f32[4, 0]" = torch.ops.aten.zeros.default([4, 0], device = device(type='cpu'), pin_memory = False)
|
|
|
|
cat_2: "f64[9, s55 + 5]" = torch.ops.aten.cat.default([cat, arg3_1], 1); cat = arg3_1 = None
|
|
|
|
sin: "f64[9, s55 + 5]" = torch.ops.aten.sin.default(cat_2)
|
|
mul: "f64[9, s55 + 5]" = torch.ops.aten.mul.Tensor(sin, 3); sin = None
|
|
|
|
sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
|
|
sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
|
|
return (
|
|
mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
|
|
cat_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
|
|
zeros_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
|
|
zeros_2, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
|
|
sym_size_int, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
|
|
sym_stride_int, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
|
)
|
|
""", # noqa: B950
|
|
)


instantiate_parametrized_tests(SubclassTests)


class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
    def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
        return get_jagged_tensor(nested_size, offsets, requires_grad)

    def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
        # Makes a non-contiguous jagged tensor with N constituent tensors, built
        # by narrowing a dense (N, max_dim, inner_dim) buffer according to the
        # given starts and lengths
        max_dim = (starts + lengths).max()
        values_tensor = torch.randn(
            starts.shape[0],
            max_dim.item(),
            inner_dim,
            requires_grad=requires_grad,
            dtype=torch.float64,
        )
        return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)

    def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
        _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles)

    def test_unary_does_not_recompile(self):
        nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
        nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None)
        self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False)

    def test_binary_does_not_recompile(self):
        def binary(nt1, nt2):
            if nt1.shape == nt2.shape:
                return nt1 + nt2
            else:
                return nt1.sin()

        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
        # This causes a recompile later on when it realizes the batch and last dim
        # should not always be equal. To avoid that, we use (3, j0, 5) here.
        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
        nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None)
        nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets)
        self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False)

    def test_binary_recompiles(self):
        def binary(nt1, nt2):
            if nt1.shape == nt2.shape:
                return nt1 + nt2
            else:
                return nt1.sin()

        # Binary recompiles because singleton ints no longer match
        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
        nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
        self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)

    def _validate_compile(self, fn, arg_fn):
        def _gen_grad_outputs(out_val):
            if isinstance(out_val, (list, tuple)):
                return tuple(torch.ones_like(c) for c in out_val)
            else:
                return (torch.ones_like(out_val),)

        with self.branch_nested_state():
            from torch.nested._internal.nested_tensor import _tensor_symint_registry

            # Validate that compilation does not modify eager state
            registry_before = list(_tensor_symint_registry.items())
            count_before = torch.nested._internal.nested_tensor._tensor_id_counter

            guards_exported = []
            guards_failed = []

            def append_guard_export(guards):
                for g in guards:
                    if g.code_list is not None:
                        guards_exported.append(g.code_list[0])

            def append_guard_fail(guards):
                guards_failed.extend(guards)

            compiled = torch._dynamo.optimize(
                nopython=True,
                backend="aot_eager",
                guard_export_fn=append_guard_export,
                guard_fail_fn=append_guard_fail,
            )(fn)
            registry_after = list(_tensor_symint_registry.items())
            count_after = torch.nested._internal.nested_tensor._tensor_id_counter
            self.assertEqual(registry_before, registry_after)
            self.assertEqual(count_before, count_after)

            args = arg_fn()
            compile_out = compiled(*args)
            compile_grads = []
            g_args = [arg for arg in args if arg.requires_grad]
            if len(g_args) > 0:
                compile_grad_outputs = _gen_grad_outputs(compile_out)
                compile_grads = torch.autograd.grad(
                    compile_out, inputs=g_args, grad_outputs=compile_grad_outputs
                )

        with self.branch_nested_state():
            args = arg_fn()
            ref_out = fn(*args)
            ref_grads = []
            g_args = [arg for arg in args if arg.requires_grad]
            if len(g_args) > 0:
                ref_grad_outputs = _gen_grad_outputs(ref_out)
                ref_grads = torch.autograd.grad(
                    ref_out, inputs=g_args, grad_outputs=ref_grad_outputs
                )

        # Validate correctness forward
        if isinstance(compile_out, (list, tuple)):
            # TODO: Fix assertEqual() to support NJTs so this isn't necessary
            self.assertEqual(len(compile_out), len(ref_out))
            for c, r in zip(compile_out, ref_out):
                self.assertEqualIgnoringNestedInts(c, r)
        else:
            self.assertEqualIgnoringNestedInts(compile_out, ref_out)

        # Validate correctness backward
        for compile_grad, ref_grad in zip(compile_grads, ref_grads):
            self.assertEqualIgnoringNestedInts(compile_grad, ref_grad)

        return guards_exported, guards_failed

    def test_in_graph_is_nested_call(self):
        def f(nt):
            if nt.is_nested:
                return nt + 2
            else:
                return nt + 1

        cnt = CompileCounterWithBackend("aot_eager")
        compiled_f = torch.compile(f, backend=cnt, fullgraph=True)
        nt, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
        output = compiled_f(nt)
        output.backward(torch.ones_like(output))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(len(cnt.graphs), 1)
        graph = cnt.graphs[0]
        norm_graph = normalize_gm(graph.print_readable(print_output=False))

        # expect -no- is_nested calls within the graph
        self.assertExpectedInline(
            norm_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"):
        l_nt_ = L_nt_

        add: "f64[3, s71, 5]" = l_nt_ + 2; l_nt_ = None
        return (add,)
""", # noqa: B950
        )

    # Note: [What kind of guards are involved in nested tensor compilation]
    #
    # Until we implement UnionFind, dynamic shapes guards are not involved;
    # we rely only on dynamo's tensor aliasing guards.
    #
    # This is possible because dynamo is able to generate tensor aliasing guards
    # not only for the outer tensor, but also for the inner tensor.
    #
    # The case where dynamic shapes guards would eventually come into play is
    # when the inputs are (1) two non-aliased tensors, but (2) declared as
    # equal using a "trust me assert equal" API.

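    # For illustration, a rough sketch of what such an aliasing guard looks like;
    # the guard string below mirrors the one asserted on in
    # test_in_graph_construction_from_input_2 further down, and nothing in this
    # comment is executed:
    #
    #     nt1 = torch.nested.nested_tensor_from_jagged(values, offsets1)
    #     nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2)
    #     # When offsets1 and offsets2 are the same tensor object, dynamo records
    #     # a guard like "L['offsets1'] is L['offsets2']" on the inner offsets
    #     # tensors; calling the compiled fn later with distinct offsets fails
    #     # that guard and triggers a recompile rather than silently reusing the
    #     # old nested int.
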
    # Note: [Compiling nested tensor global state]
    #
    # Today there are two pieces of global eager state that NJTs deal with:
    # - tensor_id_counter: a global counter that assigns unique ids to tensors
    # - tensor_symint_registry: maps tensor to nested int
    #   - this is used in eager only (we should get rid of this because it is
    #     not necessary to cache nested int in eager)
    #   - during tracing, we DO need to cache nested int, but we do so on
    #     the FakeTensor.
    #
    # Ideally we would like to satisfy the following:
    # - (1) The eager state is not mutated during tracing
    # - (2) Running the compiled function should mutate the eager state in the
    #       same way that running the eager function would
    #       (a) The global counter should be incremented
    #       (b) The registry is updated in the same way
    #
    # Today we can satisfy (1) and (2a) but cannot satisfy (2b)
    #
    # Today, (1) is satisfied because we maintain a separate counter during
    # tracing, and cache nested int on FakeTensor instead of relying on
    # tensor_symint_registry.
    #
    # (2) cannot be completely satisfied because we trace away the
    # side-effectful operations (we could fix this by wrapping the
    # side-effectful operations in a custom op, and threading through effect
    # tokens). The current plan is to do that in the UnionFind impl.
    #
    # Interestingly, despite this, the state is mutated in a way that is somewhat
    # close to what we want, e.g. if we construct a nested tensor using an
    # offsets tensor in the compiled region and return it, the AOTAutograd runtime
    # wrapper must rewrap the inner->inner graph outputs back into the subclass.
    # This triggers the eager logic to run, updating the counter and registry.
    #
    # Notably however, compile differs in two ways from eager:
    # (1) The order in which the offsets are assigned ids is different:
    #     the registry would be set in the order the offsets are returned,
    #     which is not necessarily the same order as they were constructed.
    # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping
    #     logic will not be triggered.
    #
    # I claim that correctness is not affected by these differences today,
    # e.g. there is never the case where two distinct offsets silently share
    # the same id.
    #
    # (1) is clearly not a problem, and (2) should only be a problem if
    # the nested int is returned on its own, without the corresponding NJT
    # being returned. This is not a problem in the current implementation
    # because returning only a shape is not supported!

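    # For illustration, a rough sketch of the eager state referred to above,
    # using the same private names that _validate_compile checks (nothing in
    # this comment is executed):
    #
    #     from torch.nested._internal import nested_tensor as nt_module
    #     count_before = nt_module._tensor_id_counter
    #     njt = torch.nested.nested_tensor_from_jagged(values, offsets)
    #     # Constructing an NJT in eager registers `offsets` in
    #     # nt_module._tensor_symint_registry and bumps _tensor_id_counter;
    #     # per (1) above, tracing the same construction under torch.compile
    #     # should leave both untouched, which _validate_compile asserts.
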
    # Note: [Creating symbolic nested int]
    #
    # We must create a symbolic nested int when we construct a nested tensor
    # from a tensor. There are two main cases:
    #
    # 1. The offsets tensor has NOT been used to construct an NJT
    #    - Create a new plain nested int with current val of fake nt id counter
    #    - Increment the fake nt id counter
    #    - Create a new symint with plain nested int as hint
    # 2. The offsets tensor HAS been used to construct an NJT
    #    - Create a new symint with plain nested int as hint
    #
    # More details on case 2:
    # - During fakification of the offsets, we check the eager registry, and
    #   if the tensor HAS been used to construct an NJT,
    #   we create a symint, with the existing nested int as hint, and cache
    #   it onto the FakeTensor.
    #
    # [ Always use ephemeral source ]
    #
    # We create the new symint ALWAYS with an ephemeral source, whether we are
    # in case (1) or (2), even though we could've had a proper source for case (2).
    # Using a proper source would enable a few more (edge) cases, but since
    # we plan to handle things more holistically in the future anyway, we don't
    # bother doing so today.
    #
    # Using an ephemeral source has some consequences. But we are happy if
    # - We do not silently miss recompiles, e.g. we guard when necessary.
    #   We know that this is true, because dynamo guards alone are already
    #   sufficient.
    # - We are not producing errors for the cases we care about
    #
    # The main case we care about is when we guard that two shapes are equal.
    # In this case, the replacements logic would simplify away the ephemeral
    # symbol, and there is no error produced.
    # The unsupported case is when we guard that two shapes are not equal, in
    # which case we will try and fail to generate a guard.

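    # For illustration, a rough sketch of the two cases above, as they would
    # play out for constructions inside a compiled region (mirroring the tests
    # below; nothing in this comment is executed):
    #
    #     offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
    #     nt1 = torch.nested.nested_tensor_from_jagged(values, offsets)
    #     # case 1: offsets has never backed an NJT, so a fresh nested int is
    #     # minted and the fake nt id counter is incremented.
    #     nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
    #     # case 2: offsets has already been used, so the existing nested int is
    #     # reused as the hint for the new symint, and the two jagged dims
    #     # compare equal.
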
    #
    # Case 1: in-graph construction where the offsets are passed as inputs
    #
    def test_in_graph_construction_from_input(self):
        # The offsets is passed as an input
        def fn(values, offsets):
            return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

        # Do not specialize on the offsets
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64)
            self._validate_compile(fn, arg_fn=lambda: (values, different_offsets))

    def test_in_graph_construction_from_input_2(self):
        # Construct two NJTs, both are passed as inputs
        def fn(values, offsets1, offsets2):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
            return nt2, nt1

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
        # 1. Offsets are different
        guards_exported, guards_failed = self._validate_compile(
            fn, arg_fn=lambda: (values, offsets, offsets2)
        )
        self.assertEqual(len(guards_failed), 0)
        self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported)

        # TODO
        # 2. Offsets are the same
        new_guards_exported, _ = self._validate_compile(
            fn, arg_fn=lambda: (values, offsets, offsets)
        )
        self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed))
        self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported)

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            offsets3 = offsets.clone()
            self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3))

        # Do a binary op
        def fn(values, offsets, offsets2):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
            return nt1 * nt2

        self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets))

    def test_in_graph_construction_from_input_4(self):
        # The offsets is taken from an NJT input
        def fn(nt, other_values):
            nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets())
            return nt + nt2

        values = torch.randn(9, 5, requires_grad=True)
        other_values = torch.randn(9, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)

        def arg_fn(values=values, other_values=other_values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, other_values

        self._validate_compile(fn, arg_fn=arg_fn)

        # Do not specialize on the offsets
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            different_offsets = offsets.clone()

            def arg_fn(
                values=values, other_values=other_values, offsets=different_offsets
            ):
                nt = torch.nested.nested_tensor_from_jagged(values, different_offsets)
                return nt, other_values

            self._validate_compile(fn, arg_fn=arg_fn)

    def test_in_graph_construction_from_input_5(self):
        # Construct from lengths instead of offsets
        def fn(values, lengths):
            nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
            return nt.sin()

        values = torch.randn(9, 5, requires_grad=True)
        lengths = torch.tensor([2, 4, 3])
        self._validate_compile(fn, arg_fn=lambda: (values, lengths))

    def test_in_graph_construction_from_input_6(self):
        # Construct with symbolic int.
        def fn(values, offsets, max_seqlen):
            t = torch.nested.nested_tensor_from_jagged(
                values, offsets, max_seqlen=max_seqlen
            )
            return torch.nested.nested_tensor_from_jagged(
                values, t.offsets(), max_seqlen=t._maybe_max_seqlen
            )

        opt_fn = torch.compile(fn, fullgraph=True, dynamic=True)
        values = torch.randn(10, 5)
        offsets = torch.tensor([0, 2, 4, 7, 10])
        max_seqlen = 5

        ref = fn(values, offsets, max_seqlen)
        res = opt_fn(values, offsets, max_seqlen)
        self.assertEqualIgnoringNestedInts(ref, res)

    #
    # Case 2: in-graph construction where offsets are graph intermediates
    #
    def test_in_graph_construction_from_intermediate(self):
        # offsets is an intermediate computed from lengths
        def fn(values, lengths):
            offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
            return (nt * nt2).sin()

        values = torch.randn(9, 5, requires_grad=True)
        lengths = torch.tensor([2, 4, 3])
        self._validate_compile(fn, arg_fn=lambda: (values, lengths))

        # Do not specialize on the lengths
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            different_lengths = lengths.clone()
            self._validate_compile(fn, arg_fn=lambda: (values, different_lengths))

    def test_in_graph_construction_from_intermediate_2(self):
        def fn(values, offsets):
            return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

    def test_in_graph_construction_from_intermediate_3(self):
        # Note that due to CSE, clone is not necessarily called twice!
        def fn(values, offsets):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone())
            return nt2, nt1

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

    def test_in_graph_construction_from_intermediate_4(self):
        # Shared intermediate (should be same as case #1)
        def fn(values):
            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            values2 = torch.ones_like(values)
            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets)
            return nt * nt2

        values = torch.randn(10, 5).requires_grad_(True)
        self._validate_compile(fn, arg_fn=lambda: (values,))

    # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
    @unittest.expectedFailure
    def test_in_graph_construction_from_intermediate_5(self):
        # non-shared intermediate
        def fn(values):
            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            values2 = torch.ones_like(values)
            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone())
            if nt2.shape[1] != nt.shape[1]:
                return nt * 2
            else:
                return nt * 3

        values = torch.randn(10, 5).requires_grad_(True)
        self._validate_compile(fn, arg_fn=lambda: (values,))

    #
    # Case 3: in-graph construction where offsets are both direct graph inputs
    # and passed in as part of an NJT's offsets.
    #
    def test_in_graph_construction_mixed(self):
        def fn(nt, values, offsets):
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt * nt2

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, values, offsets

        self._validate_compile(fn, arg_fn)

    # See Note: [Creating symbolic nested int]
    # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
    @unittest.expectedFailure
    def test_in_graph_construction_mixed_2(self):
        def fn(nt, values, offsets, nt2):
            # Intermediate offsets has ephemeral source
            intermediate_nt = torch.nested.nested_tensor_from_jagged(
                values, offsets.clone()
            )
            # This creates a dynamic shapes neq guard
            if nt2.shape[1] != intermediate_nt.shape[1]:
                # We should always go here.
                nt = nt * 2
            return nt

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets, offsets2=offsets2):
            # Values is shared, but it shouldn't matter
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2)
            return nt, values, offsets, nt2

        self._validate_compile(fn, arg_fn)

    def test_in_graph_construction_mixed_3(self):
        # More involved mixed case
        def fn(nt, values, offsets):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets)
            return nt1 + nt2 + nt

        values = torch.randn(9, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, values, offsets

        self._validate_compile(fn, arg_fn)

    def test_return_shape(self):
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(nt):
            return (nt * 2).shape

        compiled = torch.compile(fn, fullgraph=True, backend="aot_eager")
        compiled(nt)

    def test_inference_tensor(self):
        with torch.inference_mode():
            nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(n):
            return n * 2

        torch.compile(fn, backend="eager")(nt)

    # TODO: cannot parametrize this test class with device for some reason
    def _test_autograd(self, backend):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        # TODO: Switch to public API when it exists
        nt2, _ = jagged_from_list([a, b, c], nt.offsets())

        def fn1(nt1, nt2):
            return (nt1 + nt2).sin().cos()

        compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True)
        out = compiled_f(nt, nt2)
        out_buffer = out.values()
        ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))

        out_ref = fn1(nt, nt2)
        out_buffer_ref = out_ref.values()
        ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c))

        self.assertTrue(torch.allclose(ga, ga_ref))
        self.assertTrue(torch.allclose(gb, gb_ref))
        self.assertTrue(torch.allclose(gc, gc_ref))

    def test_basic_autograd(self):
        self._test_autograd("aot_eager")

    @requires_cuda_and_triton
    def test_basic_autograd_inductor(self):
        self._test_autograd("inductor")

    def test_subclass_with_mutation_in_graph(self):
        # In this graph, we have an in-graph mutation, i.e. a mutation that is allowed
        # to remain in the graph. Normally this is allowed, but it's not allowed if
        # the graph handles subclasses at all.
        # Whether the mutation is allowed or not allowed in the graph alters the number
        # of outputs from the forward graph. Previously, a bug in this handling meant
        # that sometimes the expected number and actual number of outputs from the
        # joint graph did not match, causing assertion failures.
        def fn(x, y):
            z = x.sin()
            y.sin_()
            return z.cos(), y.cos()

        fn_c = torch.compile(fn, backend="inductor")

        values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)]
        values_copy = [x.detach().clone().requires_grad_(True) for x in values]

        nt, offsets = jagged_from_list(values, None)
        nt_copy, offsets = jagged_from_list(values_copy, offsets)
        y = torch.rand((4, 8))
        y_copy = y.clone()

        ret = fn_c(nt, y)[0]
        ref = fn(nt_copy, y_copy)[0]

        self.assertEqual(ret.values(), ref.values())

        ret.values().sum().backward()
        ref.values().sum().backward()
        for ref_v, res_v in zip(values_copy, values):
            self.assertEqual(ref_v.grad, res_v.grad)

    @torch._dynamo.config.patch({"capture_scalar_outputs": True})
    def test_unbind(self):
        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
        # This causes a recompile later on when it realizes the batch and last dim
        # should not always be equal. To avoid that, we use (3, j0, 5) here.
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None)
        nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None)

        def fn(x):
            return x.unbind()

        compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True)
        out = compiled_f(nt)

        out_ref = fn(nt)

        # correctness
        self.assertEqual(len(out), len(out_ref))
        for x, x_ref in zip(out, out_ref):
            self.assertTrue(torch.allclose(x, x_ref))

        # We specialize on the length of offsets, e.g. (1) we recompile if the
        # length of the offsets is different. (2) we don't recompile if the
        # length of the offsets is the same, even if the size of the constituent
        # tensors are different.
        self._check_recompiles(fn, (nt,), (nt2,), False)
        self._check_recompiles(fn, (nt,), (nt3,), True)

    def test_inline_nested_tensor_from_jagged(self):
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(x):
            return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())

        torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)

    # The test here: nn.Parameters that are secretly subclasses
    # have a metaclass that overrides __instancecheck__,
    # which dynamo needs to respect when it inlines the if statement.
    def test_param_subclass_isinstance_input(self):
        x_inner = torch.randn(16, 16, requires_grad=True)
        x = torch.nn.Parameter(TwoTensor(x_inner, x_inner))
        m = torch.nn.Linear(16, 16)
        m.weight = x

        def fn():
            if isinstance(m.weight, torch.nn.Parameter):
                return m.weight + 1
            else:
                return m.weight + 2

        out_ref = fn()
        out_test = torch.compile(fn, backend="aot_eager")()
        self.assertEqual(out_ref, out_test)

    def _input_view_test(self, nt_view_name):
        nt_view = VIEW_TEST_CASES[nt_view_name]()

        def fn(x):
            return x.sin()

        out_ref = fn(nt_view)
        torch._dynamo.reset()
        compile_fn = torch.compile(
            fn, fullgraph=True, backend="aot_eager", dynamic=True
        )
        out = compile_fn(nt_view)

        # Check metadata and values are correct
        self.assertTrue(out.size() == out_ref.size())
        self.assertTrue(out.stride() == out_ref.stride())
        if out.is_nested:
            self.assertTrue(torch.allclose(out.values(), out_ref.values()))
        else:
            self.assertTrue(torch.allclose(out, out_ref))

        # Check that no upper/lower bound guards are incurred
        def backend(gm, args):
            context = torch._guards.TracingContext.get()
            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]

            # varies based on the type of view
            guard_str = "\n".join(guards)

            if nt_view_name == "base_is_nt_False_basic":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s85 - 1, s64)
Eq(s20, s64)
Eq(s80 - 1, s77)
Eq(s72, s71)""",
                )
            elif nt_view_name == "base_is_nt_False_leaf_False_False":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s85 - 1, s64)
Eq(s80 - 1, s77)
Eq(s72, s71)""",
                )
            elif nt_view_name == "base_is_nt_False_leaf_False_True":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s85 - 1, s64)
Eq(s20, s64)
Eq(s80 - 1, s77)
Eq(s72, s71)""",
                )
            elif nt_view_name == "base_is_nt_False_leaf_True_False":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s85 - 1, s64)
Eq(s20, s64)
Eq(s80 - 1, s77)
Eq(s72, s71)""",
                )
            elif nt_view_name == "base_is_nt_False_leaf_True_True":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s85 - 1, s64)
Eq(s20, s64)
Eq(s80 - 1, s77)
Eq(s72, s71)""",
                )
            elif nt_view_name == "base_is_nt_False_obscure":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s85 - 1, s64)
Eq(s20, s64)
Eq(s80 - 1, s77)
Eq(s72, s71)""",
                )
            elif nt_view_name == "base_is_nt_True_basic":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s17 - 1, s83)
Eq(s20, s83)""",
                )
            elif nt_view_name == "base_is_nt_True_leaf_False_False":
                self.assertExpectedInline(
                    guard_str,
                    """Eq(s17 - 1, s83)""",
                )
            elif nt_view_name == "base_is_nt_True_leaf_False_True":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s17 - 1, s83)
Eq(s20, s83)""",
                )
            elif nt_view_name == "base_is_nt_True_leaf_True_False":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s17 - 1, s83)
Eq(s20, s83)""",
                )
            elif nt_view_name == "base_is_nt_True_leaf_True_True":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s17 - 1, s83)
Eq(s20, s83)""",
                )
            elif nt_view_name == "base_is_nt_True_obscure":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s17 - 1, s83)
Eq(s20, s83)""",
                )
            elif nt_view_name == "dense_subclass_dense_subclass":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s85 - 1, s77)
Eq(s80 - 1, s78)
Eq(s72, s71)""",
                )
            elif nt_view_name == "subclass_dense":
                self.assertExpectedInline(
                    guard_str,
                    """\
Eq(s85 - 1, s77)
Eq(s20, s77)""",
                )
            else:
                raise NotImplementedError
            return gm

        torch._dynamo.reset()
        compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
        out = compile_fn(nt_view)

    @parametrize(
        "nt_view_name",
        [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
    )
    def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
        self._input_view_test(nt_view_name)

    def test_subclass_gives_static_shapes_when_dynamic_false(self):
        def check_graph(gm, *args):
            first_node_example_val = next(iter(gm.graph.nodes)).meta["example_value"]
            # We compiled with dynamic=False, expect no SymInt sizes on our placeholders
            self.assertTrue(
                all(isinstance(x, int) for x in first_node_example_val.shape)
            )
            return gm

        @torch.compile(backend=check_graph, dynamic=False)
        def f(x):
            return x + 1

        x_inner = torch.ones(4)
        x = TwoTensor(x_inner, x_inner)
        x_view = x.view(2, 2)
        out = f(x_view) # noqa: F841

    # NJT1 -> Dense -> NJT2 -> Dense view
    # During view replay, the Dense -> NJT2 part will construct an intermediate,
    # symbolically-sized NJT that is immediately deconstructed to return the final dense
    # view. To construct this intermediate properly, we need the associated nested int
    # to be symbolic. This view is expected to fail compilation until symbolic nested ints
    # are cached onto fake offsets to solve this problem.
    @unittest.expectedFailure
    def test_subclass_dense_subclass_dense_view(self):
        self._input_view_test("subclass_dense_subclass_dense")


instantiate_parametrized_tests(TestNestedTensor)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()