Revert "Construct NJT without graph breaks" (#133145)

This reverts commit 911154271309667b55dfb963ec6384bd0048019b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133145
Approved by: https://github.com/YuqingJ
soulitzer
2024-08-09 20:01:12 -04:00
committed by PyTorch MergeBot
parent e890d888d9
commit 05de2b2d0f
10 changed files with 40 additions and 569 deletions

View File

@ -24,7 +24,6 @@ from torch.nested._internal.nested_tensor import (
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
NestedTensorTestCase,
parametrize,
subtest,
)
@ -1701,7 +1700,7 @@ class GraphModule(torch.nn.Module):
instantiate_parametrized_tests(SubclassTests)
class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
class TestNestedTensor(torch._dynamo.test_case.TestCase):
def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
return get_jagged_tensor(nested_size, offsets, requires_grad)
@ -1755,408 +1754,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)
def _validate_compile(self, fn, arg_fn):
def _gen_grad_outputs(out_val):
if isinstance(out_val, (list, tuple)):
return tuple(torch.ones_like(c) for c in out_val)
else:
return (torch.ones_like(out_val),)
with self.branch_nested_state():
from torch.nested._internal.nested_tensor import _tensor_symint_registry
# Validate that compilation does not modify eager state
registry_before = list(_tensor_symint_registry.items())
count_before = torch.nested._internal.nested_tensor._tensor_id_counter
guards_exported = []
guards_failed = []
def append_guard_export(guards):
for g in guards:
if g.code_list is not None:
guards_exported.append(g.code_list[0])
def append_guard_fail(guards):
guards_failed.extend(guards)
compiled = torch._dynamo.optimize(
nopython=True,
backend="aot_eager",
guard_export_fn=append_guard_export,
guard_fail_fn=append_guard_fail,
)(fn)
registry_after = list(_tensor_symint_registry.items())
count_after = torch.nested._internal.nested_tensor._tensor_id_counter
self.assertEqual(registry_before, registry_after)
self.assertEqual(count_before, count_after)
args = arg_fn()
compile_out = compiled(*args)
compile_grads = []
g_args = [arg for arg in args if arg.requires_grad]
if len(g_args) > 0:
compile_grad_outputs = _gen_grad_outputs(compile_out)
compile_grads = torch.autograd.grad(
compile_out, inputs=g_args, grad_outputs=compile_grad_outputs
)
with self.branch_nested_state():
args = arg_fn()
ref_out = fn(*args)
ref_grads = []
g_args = [arg for arg in args if arg.requires_grad]
if len(g_args) > 0:
ref_grad_outputs = _gen_grad_outputs(ref_out)
ref_grads = torch.autograd.grad(
ref_out, inputs=g_args, grad_outputs=ref_grad_outputs
)
# Validate correctness forward
if isinstance(compile_out, (list, tuple)):
# TODO: Fix assertEqual() to support NJTs so this isn't necessary
self.assertEqual(len(compile_out), len(ref_out))
for c, r in zip(compile_out, ref_out):
self.assertEqualIgnoringNestedInts(c, r)
else:
self.assertEqualIgnoringNestedInts(compile_out, ref_out)
# Validate correctness backward
for compile_grad, ref_grad in zip(compile_grads, ref_grads):
self.assertEqualIgnoringNestedInts(compile_grad, ref_grad)
return guards_exported, guards_failed
# Note: [What kind of guards are involved in nested tensor compilation]
#
# Until we implement UnionFind, dynamic shapes guards are not involved.
# We rely only on dynamo's tensor aliasing guards.
#
# This is possible because dynamo is able to generate tensor aliasing guards
# not only for the outer tensor, but also for the inner tensor.
#
# The case where dynamic shapes guards would eventually come into play is
# when my inputs are (1) two non-aliased tensors, but (2) declared as
# equal using a "trust me assert equal" API.
# Note: [Compiling nested tensor global state]
#
# Today there are two pieces of global eager state that NJTs deal with:
# - tensor_id_counter: a global counter that assigns unique ids to tensors
# - tensor_symint_registry: maps tensor to nested int
# - this is used in eager only (we should get rid of this because it is
# not necessary to cache nested int in eager)
# - during tracing, we DO need to cache nested int, but we do so on
# the FakeTensor.
#
# Ideally we would like to satisfy the following:
# - (1) The eager state is not mutated during tracing
# - (2) Running the compiled function should mutate the eager state in the
# same way that running the eager function would
# (a) The global counter should be incremented
# (b) The registry is updated in the same way
#
# Today we can satisfy (1) and (2a) but cannot satisfy (2b)
#
# Today, (1) is satisfied because we maintain a separate counter during
# tracing, and cache nested int on FakeTensor instead of relying on
# tensor_symint_registry.
#
# (2) cannot be completely satisfied because we trace away the
# side-effectful operations (we can fix this by wrapping the
# side-effectful operations in a custom op and threading through effect
# tokens). The current plan is to do that in the UnionFind impl.
#
# Interestingly, despite this, the state is mutated in a way that is somewhat
# close to what we want, e.g. if I construct a nested tensor using an
# offsets in the compiled region and return it, the AOTAutograd runtime wrapper
# must rewrap the inner graph outputs back into the subclass. This
# triggers the eager logic to run, updating the counter and registry.
#
# Notably however, compile differs in two ways from eager:
# (1) The order in which the offsets are assigned ids is different:
# the registry is populated in the order the offsets are returned,
# which is not necessarily the same order as they were constructed.
# (2) If a NestedTensor is not returned, then the AOTAutograd wrapping
# logic will not be triggered.
#
# I claim that correctness is not affected by these differences today.
# e.g. there is never the case where two distinct offsets silently share
# the same id.
#
# (1) is clearly not a problem, and (2) should only be a problem if
# the nested int is returned on its own, without the corresponding NJT
# being returned. This is not a problem in the current implementation
# because returning only a shape is not supported!
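
To make the bookkeeping described above concrete, here is a minimal eager-mode sketch (it relies on the private names _tensor_id_counter and _tensor_symint_registry, which are internal and may change):

import torch
from torch.nested._internal import nested_tensor as nt_module

values = torch.randn(10, 5)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)

count_before = nt_module._tensor_id_counter
nt = torch.nested.nested_tensor_from_jagged(values, offsets)

# A fresh offsets tensor is assigned a new id (the counter is bumped once)
# and the offsets -> nested int mapping lands in the eager registry.
assert nt_module._tensor_id_counter == count_before + 1
assert offsets in nt_module._tensor_symint_registry
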
# Note: [Creating symbolic nested int]
#
# We must create a symbolic nested int when we construct a nested tensor
# from a tensor. There are two main cases:
#
# 1. The offsets has NOT been used to construct a NJT
# - Create a new plain nested int with current val of fake nt id counter
# - Increment the fake nt id counter
# - Create a new symint with plain nested int as hint
# 2. The offsets HAS been used to construct a NJT
# - Create a new symint with plain nested int as hint
#
# More details on case 2:
# - During fakification of the offsets, we check the eager registry, and
# if the tensor HAS been used to construct a NJT,
# we create a symint, with the existing nested int as hint, and cache
# it on to the FakeTensor.
#
# [ Always use ephemeral source ]
#
# We ALWAYS create the new symint with an ephemeral source, whether we are
# in case (1) or (2), even though we could have had a proper source for case (2).
# Using a proper source would enable a few more (edge) cases, but since
# we plan to handle things more holistically in the future anyway, we don't
# bother doing so today.
#
# Using an ephemeral source has some consequences. But we are happy if
# - We do not silently miss recompiles, e.g. we guard when necessary.
# We know that this is true, because dynamo guards alone are already
# sufficient.
# - We are not producing errors for the cases we care about
#
# The main case we care about is when we guard that two shapes are equal.
# In this case, the replacements logic would simplify away the ephemeral
# symbol, and there is no error produced.
# The unsupported case is when we guard that two shapes are not equal, in
# which case we will try and fail to generate a guard.
#
# Case 1: in-graph construction where the offsets are passed as inputs
#
def test_in_graph_construction_from_input(self):
# The offsets is passed as an input
def fn(values, offsets):
return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
self._validate_compile(fn, arg_fn=lambda: (values, offsets))
# Do not specialize on the offsets
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64)
self._validate_compile(fn, arg_fn=lambda: (values, different_offsets))
def test_in_graph_construction_from_input_2(self):
# Construct two NJTs, both are passed as inputs
def fn(values, offsets1, offsets2):
nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1)
nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
return nt2, nt1
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
# 1. Offsets are different
guards_exported, guards_failed = self._validate_compile(
fn, arg_fn=lambda: (values, offsets, offsets2)
)
self.assertEqual(len(guards_failed), 0)
self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported)
# TODO
# 2. Offsets are the same
new_guards_exported, _ = self._validate_compile(
fn, arg_fn=lambda: (values, offsets, offsets)
)
self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed))
self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported)
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
offsets3 = offsets.clone()
self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3))
# Do a binary op
def fn(values, offsets, offsets2):
nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
return nt1 * nt2
self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets))
def test_in_graph_construction_from_input_4(self):
# The offsets is taken from an NJT input
def fn(nt, other_values):
nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets())
return nt + nt2
values = torch.randn(9, 5, requires_grad=True)
other_values = torch.randn(9, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)
def arg_fn(values=values, other_values=other_values, offsets=offsets):
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
return nt, other_values
self._validate_compile(fn, arg_fn=arg_fn)
# Do not specialize on the offsets
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
different_offsets = offsets.clone()
def arg_fn(
values=values, other_values=other_values, offsets=different_offsets
):
nt = torch.nested.nested_tensor_from_jagged(values, different_offsets)
return nt, other_values
self._validate_compile(fn, arg_fn=arg_fn)
def test_in_graph_construction_from_input_5(self):
# Construct from lengths instead of offsets
def fn(values, lengths):
nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
return nt.sin()
values = torch.randn(9, 5, requires_grad=True)
lengths = torch.tensor([2, 4, 3])
self._validate_compile(fn, arg_fn=lambda: (values, lengths))
#
# Case 2: in-graph construction where offsets are graph intermediates
#
def test_in_graph_construction_from_intermediate(self):
# offsets is an intermediate computed from lengths
def fn(values, lengths):
offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
return (nt * nt2).sin()
values = torch.randn(9, 5, requires_grad=True)
lengths = torch.tensor([2, 4, 3])
self._validate_compile(fn, arg_fn=lambda: (values, lengths))
# Do not specialize on the lengths
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
different_lengths = lengths.clone()
self._validate_compile(fn, arg_fn=lambda: (values, different_lengths))
def test_in_graph_construction_from_intermediate_2(self):
def fn(values, offsets):
return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
self._validate_compile(fn, arg_fn=lambda: (values, offsets))
def test_in_graph_construction_from_intermediate_3(self):
# Note that due to CSE, clone is not necessarily called twice!
def fn(values, offsets):
nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone())
return nt2, nt1
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
self._validate_compile(fn, arg_fn=lambda: (values, offsets))
def test_in_graph_construction_from_intermediate_4(self):
# Shared intermediate (should be same as case #1)
def fn(values):
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
values2 = torch.ones_like(values)
nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets)
return nt * nt2
values = torch.randn(10, 5).requires_grad_(True)
self._validate_compile(fn, arg_fn=lambda: (values,))
# AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
@unittest.expectedFailure
def test_in_graph_construction_from_intermediate_5(self):
# non-shared intermediate
def fn(values):
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
values2 = torch.ones_like(values)
nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone())
if nt2.shape[1] != nt.shape[1]:
return nt * 2
else:
return nt * 3
values = torch.randn(10, 5).requires_grad_(True)
self._validate_compile(fn, arg_fn=lambda: (values,))
#
# Case 3: in-graph construction where offsets are both direct graph inputs
# and passed in as part of an NJT's offsets.
#
def test_in_graph_construction_mixed(self):
def fn(nt, values, offsets):
nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
return nt * nt2
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
def arg_fn(values=values, offsets=offsets):
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
return nt, values, offsets
self._validate_compile(fn, arg_fn)
# See Note: [Creating symbolic nested int]
# AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
@unittest.expectedFailure
def test_in_graph_construction_mixed_2(self):
def fn(nt, values, offsets, nt2):
# Intermediate offsets has ephemeral source
intermediate_nt = torch.nested.nested_tensor_from_jagged(
values, offsets.clone()
)
# This creates a dynamic shapes neq guard
if nt2.shape[1] != intermediate_nt.shape[1]:
# We should always go here.
nt = nt * 2
return nt
values = torch.randn(10, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
def arg_fn(values=values, offsets=offsets, offsets2=offsets2):
# Values is shared, but it shouldn't matter
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2)
return nt, values, offsets, nt2
self._validate_compile(fn, arg_fn)
def test_in_graph_construction_mixed_3(self):
# More involved mixed case
def fn(nt, values, offsets):
nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets)
return nt1 + nt2 + nt
values = torch.randn(9, 5, requires_grad=True)
offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)
def arg_fn(values=values, offsets=offsets):
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
return nt, values, offsets
self._validate_compile(fn, arg_fn)
def test_return_shape(self):
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
def fn(nt):
return (nt * 2).shape
compiled = torch.compile(fn, fullgraph=True, backend="aot_eager")
compiled(nt)
# TODO: cannot parametrize this test class with device for some reason
def _test_autograd(self, backend):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
@ -2312,8 +1909,8 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
guard_str,
"""\
Eq(s5 - 1, s2)
Eq(s12 - 1, s7)
Eq(s11, s9)""",
Eq(s11 - 1, s6)
Eq(s10, s8)""",
)
elif nt_view_name.startswith("base_is_nt_True"):
self.assertExpectedInline(
@ -2325,8 +1922,8 @@ Eq(s11, s9)""",
guard_str,
"""\
Eq(s4 - 1, s1)
Eq(s13 - 1, s8)
Eq(s12, s10)""",
Eq(s12 - 1, s7)
Eq(s11, s9)""",
)
return gm

View File

@ -48,17 +48,17 @@ from torch.testing._internal.common_utils import (
IS_FBCODE,
IS_WINDOWS,
markDynamoStrictTest,
NestedTensorTestCase,
parametrize,
run_tests,
skipIfSlowGradcheckEnv,
skipIfTorchDynamo,
subtest,
TEST_WITH_ROCM,
TestCase,
xfailIfTorchDynamo,
)
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
@ -265,6 +265,20 @@ def convert_nt_to_jagged(nt):
return buffer_from_jagged(nt)
# Base TestCase for NT tests; used to define common helpers, etc.
class NestedTensorTestCase(TestCase):
def assertEqualIgnoringNestedInts(self, a, b):
# unbinding NJTs allows us to compare them as essentially equal without
# caring about exact nested int comparison
def _unbind_njts(x):
if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
return x.unbind()
else:
return x
self.assertEqual(tree_map(_unbind_njts, a), tree_map(_unbind_njts, b))
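
As a usage sketch for the helper above (a hypothetical test body running inside a NestedTensorTestCase method):

values = torch.randn(9, 5)
offsets_a = torch.tensor([0, 2, 6, 9], dtype=torch.int64)
offsets_b = offsets_a.clone()  # distinct tensor, so it is assigned a distinct nested int
nt_a = torch.nested.nested_tensor_from_jagged(values, offsets_a)
nt_b = torch.nested.nested_tensor_from_jagged(values, offsets_b)

# The two NJTs hold the same data but carry different nested ints in their
# shapes; comparing the unbound constituents sidesteps that mismatch.
self.assertEqualIgnoringNestedInts(nt_a, nt_b)
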
@markDynamoStrictTest
class TestNestedTensor(NestedTensorTestCase):
@parametrize("batch_size", [2, 4])

View File

@ -172,7 +172,6 @@ def run_functionalized_fw_and_collect_metadata(
if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env):
shape_env.pending_fresh_unbacked_symbols.clear()
fake_mode.epoch += 1
fake_mode.reset_nt_tensor_id_counter()
if prior_autocast_states != _get_autocast_states():
raise RuntimeError(

View File

@ -531,13 +531,8 @@ class FakeTensorConfig:
# Making this a descriptor may seem overly fancy, but actually it's the most
# convenient way to make sure we have access to FakeTensor during access,
# which is required for testing version counter and epoch validity
class SymIntMemoDescriptor:
class UnbackedMemoDescriptor:
_name: str
_is_unbacked: bool
def __init__(self, *, is_unbacked: Optional[bool] = None):
assert is_unbacked is not None
self._is_unbacked = is_unbacked
def __set_name__(self, owner: str, name: str) -> None:
self._name = name
@ -557,20 +552,20 @@ class SymIntMemoDescriptor:
def __get__(
self, obj: FakeTensor, objtype: Optional[Type[FakeTensor]] = None
) -> Optional[torch.SymInt]:
) -> Optional[object]:
if (r := getattr(obj, self._memo(obj))) is None:
return None
# Version counter based tracking isn't 100% sound but it's close
# enough
if getattr(obj, self._memo_vc(obj)) != obj._version or (
self._is_unbacked
and getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
if (
getattr(obj, self._memo_vc(obj)) != obj._version
or getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
):
setattr(obj, self._memo(obj), None)
return None
return r
def __set__(self, obj: FakeTensor, value: Optional[torch.SymInt]) -> None:
def __set__(self, obj: FakeTensor, value: Optional[object]) -> None:
if value is None:
setattr(obj, self._memo(obj), None)
setattr(obj, self._memo_vc(obj), None)
@ -598,14 +593,9 @@ class FakeTensor(Tensor):
# TODO: Generalize this as needed, e.g., into a trie of memos, if
# you do something like x[0].item() (x[0] is fresh each time, so
# memo mechanism here won't work)
nonzero_memo = SymIntMemoDescriptor(is_unbacked=True)
item_memo = SymIntMemoDescriptor(is_unbacked=True)
unique_memo = SymIntMemoDescriptor(is_unbacked=True)
# We expect nested_int_memo to be None when an offsets is a graph
# intermediate, or an input that has never been associated with a
# nested int.
nested_int_memo = SymIntMemoDescriptor(is_unbacked=False)
nonzero_memo = UnbackedMemoDescriptor()
item_memo = UnbackedMemoDescriptor()
unique_memo = UnbackedMemoDescriptor()
# Indicates to our torch_dispatch dispatching infra that
# this is an "infra" mode with lower dispatching precedence.
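
For readers unfamiliar with the memo-descriptor pattern used in this file, a simplified, self-contained sketch follows (the class and attribute names here are illustrative, and only version-counter invalidation is modeled; the real descriptor also consults the fake mode epoch):

class MemoDescriptor:
    """Cache a per-object value; drop it when the object's _version changes."""

    def __set_name__(self, owner, name):
        self._name = name

    def _memo(self):
        return f"_{self._name}"

    def _memo_vc(self):
        return f"_{self._name}_vc"

    def __get__(self, obj, objtype=None):
        value = getattr(obj, self._memo(), None)
        if value is None:
            return None
        # Invalidate the memo if the object's version counter has moved on.
        if getattr(obj, self._memo_vc(), None) != obj._version:
            setattr(obj, self._memo(), None)
            return None
        return value

    def __set__(self, obj, value):
        setattr(obj, self._memo(), value)
        setattr(obj, self._memo_vc(), None if value is None else obj._version)
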
@ -703,7 +693,6 @@ class FakeTensor(Tensor):
self.nonzero_memo = None
self.item_memo = None
self.unique_memo = None
self.nested_int_memo = None
if FakeTensorConfig.debug:
self._debug_trace = CapturedTraceback.extract() # type: ignore[attr-defined]
@ -874,17 +863,6 @@ class FakeTensor(Tensor):
return common_device, has_scalar_only_inputs
def get_nested_int(
self,
*,
coeff: Union[int, torch.SymInt] = 1,
) -> torch.SymInt:
if self.nested_int_memo is None:
self.nested_int_memo = self.fake_mode.create_symbolic_nested_int(
nt_tensor_id=None
)
return self.nested_int_memo * coeff
# We must handle tolist in a special way for FakeTensors here in the case
# where tolist is called from torch dispatch for tensor subclasses.
# Ordinarily, if a program calls .tolist compiling still works because there is
@ -1085,16 +1063,6 @@ class FakeTensorMode(TorchDispatchMode):
_stack: Optional[str]
allow_meta: bool
# NestedTensor uses a tensor_id_counter to uniquely identify offsets.
# This counter is incremented when an offsets is used to create an NJT
# for the first time. To avoid mutating eager state if we construct an NJT
# during tracing, we maintain a separate counter on the FakeTensorMode.
# The initial count is set to the current eager tensor_id_counter value
# upon initialization, and every time you retrace using the same fake tensor
# mode, you should reset the counter to the initial count.
nt_tensor_id_counter: int = -1
nt_tensor_id_initial_count: int = -1
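
A small sketch of the retracing contract described in the comment above, using the pre-revert nt_tensor_id_* attributes:

from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
# Tracing NJT constructions from fresh offsets hands out ids from this
# fake-mode-local counter instead of the global eager counter.
fake_mode.nt_tensor_id_counter += 3
# Before retracing with the same mode, reset back to the count captured from
# the eager counter when the mode was created.
fake_mode.reset_nt_tensor_id_counter()
assert fake_mode.nt_tensor_id_counter == fake_mode.nt_tensor_id_initial_count
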
def __init__(
self,
*,
@ -1183,16 +1151,6 @@ class FakeTensorMode(TorchDispatchMode):
# this is an "infra" mode with lower dispatching precedence.
self._mode_key = torch._C._TorchDispatchModeKey.FAKE
import torch.nested._internal.nested_tensor
self.nt_tensor_id_initial_count = (
torch.nested._internal.nested_tensor._tensor_id_counter
)
self.nt_tensor_id_counter = self.nt_tensor_id_initial_count
def reset_nt_tensor_id_counter(self) -> None:
self.nt_tensor_id_counter = self.nt_tensor_id_initial_count
# Typically, there is only one fake tensor mode and you test for it by
# doing an isinstance test. However, in some situations, there might be
# TWO fake tensor modes. The canonical example of this is exporting
@ -1247,8 +1205,6 @@ class FakeTensorMode(TorchDispatchMode):
# No-op if FakeTensorMode is already in use
def __enter__(self) -> Self:
import torch.nested._internal.nested_tensor
prev_only_lift_cpu_tensors = None
if self.avoid_device_init:
# See NOTE: [torch.tensor, lift_fresh, and device movement]
@ -2144,31 +2100,6 @@ class FakeTensorMode(TorchDispatchMode):
return tree_map(wrap, r)
def create_symbolic_nested_int(
self, *, nt_tensor_id: Optional[int] = None
) -> torch.SymInt:
# See Note: [Creating symbolic nested int]
# Returned nested int always has coeff=1; multiply the result by coeff if needed
import torch.nested._internal.nested_tensor
if nt_tensor_id is None:
nt_tensor_id = self.nt_tensor_id_counter
assert self.enter_stack, "should only be called while FakeTensorMode is active"
self.nt_tensor_id_counter += 1
hint = torch._C._get_nested_int(nt_tensor_id, 1)
src = torch._dynamo.source.EphemeralSource("intermediate_offsets_or_lengths")
assert self.shape_env is not None
ret = self.shape_env.create_symintnode(
sym=self.shape_env.create_symbol(
val=hint,
source=src,
),
hint=hint,
source=src,
)
return ret
_cpp_meta_supports_symint = ordered_set(
aten.empty.memory_format,
aten.empty_strided.default,

View File

@ -743,9 +743,3 @@ class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
def mark_mutation_hidden_from_autograd(self, tensor) -> None:
torch._functionalize_mark_mutation_hidden_from_autograd(tensor)
def mb_unwrap_functional_tensor(tensor: torch.Tensor):
if isinstance(tensor, FunctionalTensor):
return torch._from_functional_tensor(tensor.elem)
return tensor

View File

@ -305,8 +305,6 @@ class MetaTensorDescriber:
}
type_v = type(t)
from torch.nested._internal.nested_tensor import _tensor_symint_registry
# TODO: Is it important to enable torch.inference_mode before querying
# these values?
r = MetaTensorDesc(
@ -335,11 +333,6 @@ class MetaTensorDescriber:
is_parameter=isinstance(t, torch.nn.Parameter),
is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v,
is_nested=is_nested,
nested_int=(
_tensor_symint_registry[t].node.nested_int()
if t in _tensor_symint_registry
else None
),
is_functional=is_functional,
layout=layout,
device=t.device,
@ -472,10 +465,6 @@ class MetaTensorDesc:
is_gradtrackingtensor: bool = False
is_view: bool = False
is_nested: bool = False
# We eagerly symbolicize the associated nested int for e.g. offsets / lengths
# metadata if that offsets is already associated with a nested int.
# See test_construct_from_jagged_with_input_offsets_mixed_case.
nested_int: Optional[int] = None
is_traceable_wrapper_subclass: bool = False
is_functional: bool = False
is_conj: bool = False
@ -515,7 +504,6 @@ class MetaTensorDesc:
"functorch_stack",
"autograd_meta_from",
"data",
"nested_int",
]
ctx: Optional[object] = None # is_traceable_wrapper_subclass
@ -1575,12 +1563,6 @@ class MetaConverter:
if t.is_parameter:
r._is_param = True
# See Note: [Creating symbolic nested int]
if t.nested_int is not None:
r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
nt_tensor_id=t.nested_int
)
self.set_tensor_memo(t, r)
return self.get_tensor_memo(t)

View File

@ -30,7 +30,6 @@ class FakeTensorProp(torch.fx.Interpreter):
mode = FakeTensorMode()
self._mode = mode
mode.epoch += 1
mode.reset_nt_tensor_id_counter()
def run_node(self, n: Node):
from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings

View File

@ -13,16 +13,7 @@ _tensor_symint_registry = WeakTensorKeyDictionary()
def get_tensor_symint(tensor, *, coeff=1):
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor
# NB: Only FakeTensor is associated with a memo
tensor = mb_unwrap_functional_tensor(tensor)
if isinstance(tensor, FakeTensor):
return tensor.get_nested_int(coeff=coeff)
global _tensor_id_counter
tensor_symint = _tensor_symint_registry.get(tensor)
if tensor_symint is None:
tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
@ -251,7 +242,7 @@ class NestedTensor(torch.Tensor):
@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import has_free_symbols
# inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen]
assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5
@ -266,14 +257,18 @@ class NestedTensor(torch.Tensor):
metadata_cache["min_seqlen"] = min_seqlen_tensor
if max_seqlen_tensor is not None:
metadata_cache["max_seqlen"] = max_seqlen_tensor
ragged_idx = meta["ragged_idx"]
# Alternatively, we could make it the caller's responsibility to
# cache it. But this heuristic seems simple enough.
# Note that we cannot simply check if is_fake(values) because
# during aot autograd, FunctionalTensors are not fake but hold
# symbolic sizes.
ragged_source = offsets if lengths is None else lengths
if isinstance(ragged_source, FakeTensor):
if has_free_symbols(ragged_source) or has_free_symbols(values):
# Associate offsets or lengths (possibly fake, possibly functionalized)
# with the ragged_size.
ragged_size = outer_size[ragged_idx]
ragged_source.nested_int_memo = ragged_size
_tensor_symint_registry[ragged_source] = ragged_size
return NestedTensor(
values,

View File

@ -511,20 +511,7 @@ def _to_copy_default(func, *args, **kwargs):
# Copy to a new Python subclass NestedTensor
new_offsets = inp._offsets.to(device=new_values.device)
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import (
FunctionalTensor,
mb_unwrap_functional_tensor,
)
if isinstance(new_offsets, (FakeTensor, FunctionalTensor)):
# Temporary hack until we have the union find
tgt = mb_unwrap_functional_tensor(new_offsets)
src = mb_unwrap_functional_tensor(inp._offsets)
tgt.nested_int_memo = src.nested_int_memo
else:
_tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets]
_tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets]
inp_kwargs = extract_kwargs(inp)
inp_kwargs["offsets"] = new_offsets

View File

@ -5189,33 +5189,6 @@ def make_lazy_class(cls):
return cls
# Base TestCase for NT tests; used to define common helpers, etc.
class NestedTensorTestCase(TestCase):
def assertEqualIgnoringNestedInts(self, a, b):
# unbinding NJTs allows us to compare them as essentially equal without
# caring about exact nested int comparison
def _unbind_njts(x):
if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
return x.unbind()
else:
return x
self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b))
@contextlib.contextmanager
def branch_nested_state(self):
"""Context manager to branch and restore the nested tensor state."""
nested_tensor_module = torch.nested._internal.nested_tensor
original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy()
original_tensor_id_counter = nested_tensor_module._tensor_id_counter
try:
yield
finally:
nested_tensor_module._tensor_id_counter = original_tensor_id_counter
nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry
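
A usage sketch for the context manager above (hypothetical body of a NestedTensorTestCase test):

nested_tensor_module = torch.nested._internal.nested_tensor
count_before = nested_tensor_module._tensor_id_counter

with self.branch_nested_state():
    # Constructing an NJT inside the branch bumps the global counter and
    # registers the offsets, but both pieces of state are restored on exit.
    values = torch.randn(10, 5)
    offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
    torch.nested.nested_tensor_from_jagged(values, offsets)

assert nested_tensor_module._tensor_id_counter == count_before
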
@make_lazy_class
class LazyVal:
pass