Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

Revert "Construct NJT without graph breaks" (#133145)

This reverts commit 911154271309667b55dfb963ec6384bd0048019b.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133145
Approved by: https://github.com/YuqingJ

Committed by: PyTorch MergeBot
Parent: e890d888d9
Commit: 05de2b2d0f
@@ -24,7 +24,6 @@ from torch.nested._internal.nested_tensor import (
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    NestedTensorTestCase,
    parametrize,
    subtest,
)

@@ -1701,7 +1700,7 @@ class GraphModule(torch.nn.Module):
instantiate_parametrized_tests(SubclassTests)


class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
class TestNestedTensor(torch._dynamo.test_case.TestCase):
    def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
        return get_jagged_tensor(nested_size, offsets, requires_grad)

@@ -1755,408 +1754,6 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
        nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
        self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)

    def _validate_compile(self, fn, arg_fn):
        def _gen_grad_outputs(out_val):
            if isinstance(out_val, (list, tuple)):
                return tuple(torch.ones_like(c) for c in out_val)
            else:
                return (torch.ones_like(out_val),)

        with self.branch_nested_state():
            from torch.nested._internal.nested_tensor import _tensor_symint_registry

            # Validate that compilation does not modify eager state
            registry_before = list(_tensor_symint_registry.items())
            count_before = torch.nested._internal.nested_tensor._tensor_id_counter

            guards_exported = []
            guards_failed = []

            def append_guard_export(guards):
                for g in guards:
                    if g.code_list is not None:
                        guards_exported.append(g.code_list[0])

            def append_guard_fail(guards):
                guards_failed.extend(guards)

            compiled = torch._dynamo.optimize(
                nopython=True,
                backend="aot_eager",
                guard_export_fn=append_guard_export,
                guard_fail_fn=append_guard_fail,
            )(fn)
            registry_after = list(_tensor_symint_registry.items())
            count_after = torch.nested._internal.nested_tensor._tensor_id_counter
            self.assertEqual(registry_before, registry_after)
            self.assertEqual(count_before, count_after)

            args = arg_fn()
            compile_out = compiled(*args)
            compile_grads = []
            g_args = [arg for arg in args if arg.requires_grad]
            if len(g_args) > 0:
                compile_grad_outputs = _gen_grad_outputs(compile_out)
                compile_grads = torch.autograd.grad(
                    compile_out, inputs=g_args, grad_outputs=compile_grad_outputs
                )

        with self.branch_nested_state():
            args = arg_fn()
            ref_out = fn(*args)
            ref_grads = []
            g_args = [arg for arg in args if arg.requires_grad]
            if len(g_args) > 0:
                ref_grad_outputs = _gen_grad_outputs(ref_out)
                ref_grads = torch.autograd.grad(
                    ref_out, inputs=g_args, grad_outputs=ref_grad_outputs
                )

        # Validate correctness forward
        if isinstance(compile_out, (list, tuple)):
            # TODO: Fix assertEqual() to support NJTs so this isn't necessary
            self.assertEqual(len(compile_out), len(ref_out))
            for c, r in zip(compile_out, ref_out):
                self.assertEqualIgnoringNestedInts(c, r)
        else:
            self.assertEqualIgnoringNestedInts(compile_out, ref_out)

        # Validate correctness backward
        for compile_grad, ref_grad in zip(compile_grads, ref_grads):
            self.assertEqualIgnoringNestedInts(compile_grad, ref_grad)

        return guards_exported, guards_failed

    # Note: [What kind of guards are involved in nested tensor compilation]
    #
    # Until we implement UnionFind, dynamic shapes guards are not involved;
    # we rely only on dynamo's tensor aliasing guards.
    #
    # This is possible because dynamo is able to generate tensor aliasing guards
    # not only for the outer tensor, but also for the inner tensor.
    #
    # The case where dynamic shapes guards would eventually come into play is
    # when the inputs are (1) two non-aliased tensors, but (2) declared as
    # equal using a "trust me assert equal" API.

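# --- Editor's illustration (not part of the diff) ---
# A minimal sketch of where the aliasing guards described above can be observed.
# It reuses the guard_export_fn hook shown in _validate_compile; exact guard
# strings are an implementation detail, so this only shows the mechanism.
import torch

def fn(nt1, nt2):
    return nt1.values() + nt2.values()

exported = []

def on_guards(guards):
    for g in guards:
        if g.code_list:
            exported.extend(g.code_list)

compiled = torch._dynamo.optimize("aot_eager", guard_export_fn=on_guards)(fn)

values = torch.randn(10, 5)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
nt1 = torch.nested.nested_tensor_from_jagged(values, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(values.clone(), offsets)  # shares offsets
compiled(nt1, nt2)
# With shared offsets one would expect an aliasing guard over the inner offsets
# tensors to appear among `exported`, rather than any dynamic-shape guard.
# --- end illustration ---
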
    # Note: [Compiling nested tensor global state]
    #
    # Today there are two pieces of global eager state that NJTs deal with:
    # - tensor_id_counter: a global counter that assigns unique ids to tensors
    # - tensor_symint_registry: maps tensor to nested int
    #   - this is used in eager only (we should get rid of this because it is
    #     not necessary to cache nested int in eager)
    #   - during tracing, we DO need to cache nested int, but we do so on
    #     the FakeTensor.
    #
    # Ideally we would like to satisfy the following:
    # - (1) The eager state is not mutated during tracing
    # - (2) Running the compiled function should mutate the eager state in the
    #       same way that running the eager function would
    #     (a) The global counter should be incremented
    #     (b) The registry is updated in the same way
    #
    # Today we can satisfy (1) and (2a) but cannot satisfy (2b)
    #
    # Today, (1) is satisfied because we maintain a separate counter during
    # tracing, and cache nested int on FakeTensor instead of relying on
    # tensor_symint_registry.
    #
    # (2) cannot be completely satisfied because we trace away the
    # side-effectful operations (which we can fix by wrapping the
    # side-effectful operations in a custom op, and threading through effect
    # tokens). The current plan is to do that in the UnionFind impl.
    #
    # Interestingly, despite this, the state is mutated in a way that is somewhat
    # close to what we want, e.g. if we construct a nested tensor from an offsets
    # tensor in the compiled region and return it, the AOTAutograd runtime wrapper
    # must rewrap the inner->inner graph outputs back into the subclass. This
    # triggers the eager logic to run, updating the counter and registry.
    #
    # Notably however, compile differs in two ways from eager:
    # (1) The order in which the offsets are assigned ids is different:
    #     the registry would be set in the order the offsets are returned,
    #     which is not necessarily the same order as they were constructed.
    # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping
    #     logic will not be triggered.
    #
    # I claim that correctness is not affected by these differences today,
    # e.g. there is never a case where two distinct offsets silently share
    # the same id.
    #
    # (1) is clearly not a problem, and (2) should only be a problem if
    # the nested int is returned on its own, without the corresponding NJT
    # being returned. This is not a problem in the current implementation
    # because returning only a shape is not supported!

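# --- Editor's illustration (not part of the diff) ---
# A sketch of the eager global state described above. The module attributes
# referenced here (_tensor_id_counter, _tensor_symint_registry) are internal
# and may change; this shows the bookkeeping, not a stable API.
import torch
import torch.nested._internal.nested_tensor as nt_module

values = torch.randn(10, 5)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)

count_before = nt_module._tensor_id_counter
nt = torch.nested.nested_tensor_from_jagged(values, offsets)

# First use of `offsets` registers it and advances the global counter ...
assert offsets in nt_module._tensor_symint_registry
count_after_first = nt_module._tensor_id_counter
assert count_after_first > count_before

# ... while re-using the same offsets tensor hits the registry and allocates
# no new id.
nt2 = torch.nested.nested_tensor_from_jagged(torch.ones_like(values), offsets)
assert nt_module._tensor_id_counter == count_after_first
# --- end illustration ---
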
    # Note: [Creating symbolic nested int]
    #
    # We must create a symbolic nested int when we construct a nested tensor
    # from a tensor. There are two main cases:
    #
    # 1. The offsets tensor has NOT been used to construct an NJT
    #    - Create a new plain nested int with the current val of the fake nt id counter
    #    - Increment the fake nt id counter
    #    - Create a new symint with the plain nested int as hint
    # 2. The offsets tensor HAS been used to construct an NJT
    #    - Create a new symint with the plain nested int as hint
    #
    # More details on case 2:
    # - During fakification of the offsets, we check the eager registry, and
    #   if the tensor HAS been used to construct an NJT,
    #   we create a symint, with the existing nested int as hint, and cache
    #   it onto the FakeTensor.
    #
    # [ Always use ephemeral source ]
    #
    # We create the new symint ALWAYS with an ephemeral source, whether that is
    # in case (1) or (2), even though we could've had a proper source for case (2).
    # Using a proper source would enable a few more (edge) cases, but since
    # we plan to handle things more holistically in the future anyway, we don't
    # bother doing so today.
    #
    # Using an ephemeral source has some consequences. But we are happy if
    # - We do not silently miss recompiles, e.g. we guard when necessary.
    #   We know that this is true, because dynamo guards alone are already
    #   sufficient.
    # - We are not producing errors for the cases we care about.
    #
    # The main case we care about is when we guard that two shapes are equal.
    # In this case, the replacements logic would simplify away the ephemeral
    # symbol, and there is no error produced.
    # The unsupported case is when we guard that two shapes are not equal, in
    # which case we will try and fail to generate a guard.

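# --- Editor's illustration (not part of the diff) ---
# A toy model of the two cases above: "nested ints" are plain integers, the
# eager registry is a dict, and the fake-mode counter is a simple attribute.
# This mirrors only the bookkeeping, not the real SymInt machinery.
class ToyFakeMode:
    def __init__(self, initial_count: int):
        self.nt_tensor_id_counter = initial_count

def symbolic_nested_int(offsets_key, fake_mode: ToyFakeMode, eager_registry: dict):
    if offsets_key in eager_registry:
        # Case 2: offsets were already used to build an NJT in eager;
        # reuse the existing nested int as the hint.
        return eager_registry[offsets_key]
    # Case 1: fresh offsets; allocate a new id from the fake counter and bump it.
    hint = fake_mode.nt_tensor_id_counter
    fake_mode.nt_tensor_id_counter += 1
    return hint

fake_mode = ToyFakeMode(initial_count=100)
registry = {"seen_offsets": 7}
assert symbolic_nested_int("seen_offsets", fake_mode, registry) == 7     # case 2
assert symbolic_nested_int("fresh_offsets", fake_mode, registry) == 100  # case 1
assert fake_mode.nt_tensor_id_counter == 101
# --- end illustration ---
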
    #
    # Case 1: in-graph construction where the offsets are passed as inputs
    #
    def test_in_graph_construction_from_input(self):
        # The offsets tensor is passed as an input
        def fn(values, offsets):
            return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

        # Do not specialize on the offsets
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64)
            self._validate_compile(fn, arg_fn=lambda: (values, different_offsets))

    def test_in_graph_construction_from_input_2(self):
        # Construct two NJTs; both are passed as inputs
        def fn(values, offsets1, offsets2):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
            return nt2, nt1

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
        # 1. Offsets are different
        guards_exported, guards_failed = self._validate_compile(
            fn, arg_fn=lambda: (values, offsets, offsets2)
        )
        self.assertEqual(len(guards_failed), 0)
        self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported)

        # TODO
        # 2. Offsets are the same
        new_guards_exported, _ = self._validate_compile(
            fn, arg_fn=lambda: (values, offsets, offsets)
        )
        self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed))
        self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported)

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            offsets3 = offsets.clone()
            self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3))

        # Do a binary op
        def fn(values, offsets, offsets2):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2)
            return nt1 * nt2

        self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets))

    def test_in_graph_construction_from_input_4(self):
        # The offsets tensor is taken from an NJT input
        def fn(nt, other_values):
            nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets())
            return nt + nt2

        values = torch.randn(9, 5, requires_grad=True)
        other_values = torch.randn(9, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)

        def arg_fn(values=values, other_values=other_values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, other_values

        self._validate_compile(fn, arg_fn=arg_fn)

        # Do not specialize on the offsets
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            different_offsets = offsets.clone()

            def arg_fn(
                values=values, other_values=other_values, offsets=different_offsets
            ):
                nt = torch.nested.nested_tensor_from_jagged(values, different_offsets)
                return nt, other_values

            self._validate_compile(fn, arg_fn=arg_fn)

    def test_in_graph_construction_from_input_5(self):
        # Construct from lengths instead of offsets
        def fn(values, lengths):
            nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
            return nt.sin()

        values = torch.randn(9, 5, requires_grad=True)
        lengths = torch.tensor([2, 4, 3])
        self._validate_compile(fn, arg_fn=lambda: (values, lengths))

    #
    # Case 2: in-graph construction where offsets are graph intermediates
    #
    def test_in_graph_construction_from_intermediate(self):
        # offsets is an intermediate computed from lengths
        def fn(values, lengths):
            offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
            return (nt * nt2).sin()

        values = torch.randn(9, 5, requires_grad=True)
        lengths = torch.tensor([2, 4, 3])
        self._validate_compile(fn, arg_fn=lambda: (values, lengths))

        # Do not specialize on the lengths
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            different_lengths = lengths.clone()
            self._validate_compile(fn, arg_fn=lambda: (values, different_lengths))

    def test_in_graph_construction_from_intermediate_2(self):
        def fn(values, offsets):
            return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

    def test_in_graph_construction_from_intermediate_3(self):
        # Note that due to CSE, clone is not necessarily called twice!
        def fn(values, offsets):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone())
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone())
            return nt2, nt1

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        self._validate_compile(fn, arg_fn=lambda: (values, offsets))

    def test_in_graph_construction_from_intermediate_4(self):
        # Shared intermediate (should be same as case #1)
        def fn(values):
            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            values2 = torch.ones_like(values)
            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets)
            return nt * nt2

        values = torch.randn(10, 5).requires_grad_(True)
        self._validate_compile(fn, arg_fn=lambda: (values,))

    # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
    @unittest.expectedFailure
    def test_in_graph_construction_from_intermediate_5(self):
        # non-shared intermediate
        def fn(values):
            offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            values2 = torch.ones_like(values)
            nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone())
            if nt2.shape[1] != nt.shape[1]:
                return nt * 2
            else:
                return nt * 3

        values = torch.randn(10, 5).requires_grad_(True)
        self._validate_compile(fn, arg_fn=lambda: (values,))

    #
    # Case 3: in-graph construction where offsets are both direct graph inputs
    # and passed in as part of an NJT's offsets.
    #
    def test_in_graph_construction_mixed(self):
        def fn(nt, values, offsets):
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt * nt2

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, values, offsets

        self._validate_compile(fn, arg_fn)

    # See Note: [Creating symbolic nested int]
    # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>',
    @unittest.expectedFailure
    def test_in_graph_construction_mixed_2(self):
        def fn(nt, values, offsets, nt2):
            # Intermediate offsets has ephemeral source
            intermediate_nt = torch.nested.nested_tensor_from_jagged(
                values, offsets.clone()
            )
            # This creates a dynamic shapes neq guard
            if nt2.shape[1] != intermediate_nt.shape[1]:
                # We should always go here.
                nt = nt * 2
            return nt

        values = torch.randn(10, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
        offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets, offsets2=offsets2):
            # Values is shared, but it shouldn't matter
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2)
            return nt, values, offsets, nt2

        self._validate_compile(fn, arg_fn)

    def test_in_graph_construction_mixed_3(self):
        # More involved mixed case
        def fn(nt, values, offsets):
            nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets)
            nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets)
            return nt1 + nt2 + nt

        values = torch.randn(9, 5, requires_grad=True)
        offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64)

        def arg_fn(values=values, offsets=offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            return nt, values, offsets

        self._validate_compile(fn, arg_fn)

    def test_return_shape(self):
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(nt):
            return (nt * 2).shape

        compiled = torch.compile(fn, fullgraph=True, backend="aot_eager")
        compiled(nt)

    # TODO: cannot parametrize this test class with device for some reason
    def _test_autograd(self, backend):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
@@ -2312,8 +1909,8 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
                guard_str,
                """\
Eq(s5 - 1, s2)
Eq(s12 - 1, s7)
Eq(s11, s9)""",
Eq(s11 - 1, s6)
Eq(s10, s8)""",
            )
        elif nt_view_name.startswith("base_is_nt_True"):
            self.assertExpectedInline(

@@ -2325,8 +1922,8 @@ Eq(s11, s9)""",
                guard_str,
                """\
Eq(s4 - 1, s1)
Eq(s13 - 1, s8)
Eq(s12, s10)""",
Eq(s12 - 1, s7)
Eq(s11, s9)""",
            )
        return gm

@@ -48,17 +48,17 @@ from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_WINDOWS,
    markDynamoStrictTest,
    NestedTensorTestCase,
    parametrize,
    run_tests,
    skipIfSlowGradcheckEnv,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_ROCM,
    TestCase,
    xfailIfTorchDynamo,
)
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts


@@ -265,6 +265,20 @@ def convert_nt_to_jagged(nt):
    return buffer_from_jagged(nt)


# Base TestCase for NT tests; used to define common helpers, etc.
class NestedTensorTestCase(TestCase):
    def assertEqualIgnoringNestedInts(self, a, b):
        # unbinding NJTs allows us to compare them as essentially equal without
        # caring about exact nested int comparison
        def _unbind_njts(x):
            if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
                return x.unbind()
            else:
                return x

        self.assertEqual(tree_map(_unbind_njts, a), tree_map(_unbind_njts, b))


@markDynamoStrictTest
class TestNestedTensor(NestedTensorTestCase):
    @parametrize("batch_size", [2, 4])
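# --- Editor's illustration (not part of the diff) ---
# The unbind-based comparison used by assertEqualIgnoringNestedInts above:
# two NJTs built from distinct (but equal) offsets tensors carry different
# nested ints, yet unbind() yields per-sample dense tensors that compare equal.
import torch

values = torch.randn(10, 5)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)
nt_a = torch.nested.nested_tensor_from_jagged(values, offsets)
nt_b = torch.nested.nested_tensor_from_jagged(values.clone(), offsets.clone())

for a, b in zip(nt_a.unbind(), nt_b.unbind()):
    torch.testing.assert_close(a, b)
# --- end illustration ---
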
@@ -172,7 +172,6 @@ def run_functionalized_fw_and_collect_metadata(
    if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env):
        shape_env.pending_fresh_unbacked_symbols.clear()
        fake_mode.epoch += 1
        fake_mode.reset_nt_tensor_id_counter()

    if prior_autocast_states != _get_autocast_states():
        raise RuntimeError(

@@ -531,13 +531,8 @@ class FakeTensorConfig:
# Making this a descriptor may seem overly fancy, but actually it's the most
# convenient way to make sure we have access to FakeTensor during access,
# which is required for testing version counter and epoch validity
class SymIntMemoDescriptor:
class UnbackedMemoDescriptor:
    _name: str
    _is_unbacked: bool

    def __init__(self, *, is_unbacked: Optional[bool] = None):
        assert is_unbacked is not None
        self._is_unbacked = is_unbacked

    def __set_name__(self, owner: str, name: str) -> None:
        self._name = name

@@ -557,20 +552,20 @@ class SymIntMemoDescriptor:

    def __get__(
        self, obj: FakeTensor, objtype: Optional[Type[FakeTensor]] = None
    ) -> Optional[torch.SymInt]:
    ) -> Optional[object]:
        if (r := getattr(obj, self._memo(obj))) is None:
            return None
        # Version counter based tracking isn't 100% sound but it's close
        # enough
        if getattr(obj, self._memo_vc(obj)) != obj._version or (
            self._is_unbacked
            and getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
        if (
            getattr(obj, self._memo_vc(obj)) != obj._version
            or getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
        ):
            setattr(obj, self._memo(obj), None)
            return None
        return r

    def __set__(self, obj: FakeTensor, value: Optional[torch.SymInt]) -> None:
    def __set__(self, obj: FakeTensor, value: Optional[object]) -> None:
        if value is None:
            setattr(obj, self._memo(obj), None)
            setattr(obj, self._memo_vc(obj), None)
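# --- Editor's illustration (not part of the diff) ---
# A simplified, standalone model of the memo-descriptor idea above: a cached
# value is only served while the owner's version counter and epoch still match
# the values recorded when the memo was stored. Generic Python, not the real
# FakeTensor implementation.
class MemoDescriptor:
    def __set_name__(self, owner, name):
        self._name = name

    def __get__(self, obj, objtype=None):
        memo = obj.__dict__.get(self._name)
        if memo is None:
            return None
        value, version, epoch = memo
        if version != obj._version or epoch != obj._epoch:
            obj.__dict__[self._name] = None  # stale: drop it
            return None
        return value

    def __set__(self, obj, value):
        if value is None:
            obj.__dict__[self._name] = None
        else:
            obj.__dict__[self._name] = (value, obj._version, obj._epoch)

class Owner:
    memo = MemoDescriptor()

    def __init__(self):
        self._version = 0
        self._epoch = 0

o = Owner()
o.memo = "cached"
assert o.memo == "cached"
o._version += 1          # simulate a mutation of the underlying tensor
assert o.memo is None    # the memo is invalidated rather than served stale
# --- end illustration ---
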
@@ -598,14 +593,9 @@ class FakeTensor(Tensor):
    # TODO: Generalize this as needed, e.g., into a trie of memos, if
    # you do something like x[0].item() (x[0] is fresh each time, so
    # memo mechanism here won't work)
    nonzero_memo = SymIntMemoDescriptor(is_unbacked=True)
    item_memo = SymIntMemoDescriptor(is_unbacked=True)
    unique_memo = SymIntMemoDescriptor(is_unbacked=True)

    # We expect nested_int_memo to be None when an offsets is a graph
    # intermediate, or an input that has never been associated with a
    # nested int.
    nested_int_memo = SymIntMemoDescriptor(is_unbacked=False)
    nonzero_memo = UnbackedMemoDescriptor()
    item_memo = UnbackedMemoDescriptor()
    unique_memo = UnbackedMemoDescriptor()

    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.

@@ -703,7 +693,6 @@ class FakeTensor(Tensor):
        self.nonzero_memo = None
        self.item_memo = None
        self.unique_memo = None
        self.nested_int_memo = None

        if FakeTensorConfig.debug:
            self._debug_trace = CapturedTraceback.extract()  # type: ignore[attr-defined]

@@ -874,17 +863,6 @@ class FakeTensor(Tensor):

        return common_device, has_scalar_only_inputs

    def get_nested_int(
        self,
        *,
        coeff: Union[int, torch.SymInt] = 1,
    ) -> torch.SymInt:
        if self.nested_int_memo is None:
            self.nested_int_memo = self.fake_mode.create_symbolic_nested_int(
                nt_tensor_id=None
            )
        return self.nested_int_memo * coeff

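# --- Editor's illustration (not part of the diff) ---
# A toy version of the memoization pattern in get_nested_int above: the first
# request creates and caches a value, later requests reuse it, and coeff only
# scales the cached value. Illustrative only; ToyOffsets is not a real class.
class ToyOffsets:
    def __init__(self, make_nested_int):
        self._make = make_nested_int
        self.nested_int_memo = None

    def get_nested_int(self, *, coeff=1):
        if self.nested_int_memo is None:
            self.nested_int_memo = self._make()
        return self.nested_int_memo * coeff

calls = []
offsets = ToyOffsets(make_nested_int=lambda: calls.append(1) or 11)
assert offsets.get_nested_int() == 11
assert offsets.get_nested_int(coeff=2) == 22
assert len(calls) == 1  # the underlying nested int was created only once
# --- end illustration ---
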
    # We must handle tolist in a special way for FakeTensors here in the case
    # where tolist is called from torch dispatch for tensor subclasses.
    # Ordinarily, if a program calls .tolist compiling still works because there is

@@ -1085,16 +1063,6 @@ class FakeTensorMode(TorchDispatchMode):
    _stack: Optional[str]
    allow_meta: bool

    # NestedTensor uses a tensor_id_counter to uniquely identify offsets.
    # This counter is incremented when an offsets is used to create an NJT
    # for the first time. To avoid mutating eager state if we construct NJT
    # during tracing, we maintain a separate counter on the FakeTensorMode.
    # The initial count is set to the current eager tensor_id_counter value
    # upon initialization, and every time you retrace using the same fake tensor
    # mode, you should reset the counter to the initial count.
    nt_tensor_id_counter: int = -1
    nt_tensor_id_initial_count: int = -1

    def __init__(
        self,
        *,

@@ -1183,16 +1151,6 @@ class FakeTensorMode(TorchDispatchMode):
        # this is an "infra" mode with lower dispatching precedence.
        self._mode_key = torch._C._TorchDispatchModeKey.FAKE

        import torch.nested._internal.nested_tensor

        self.nt_tensor_id_initial_count = (
            torch.nested._internal.nested_tensor._tensor_id_counter
        )
        self.nt_tensor_id_counter = self.nt_tensor_id_initial_count

    def reset_nt_tensor_id_counter(self) -> None:
        self.nt_tensor_id_counter = self.nt_tensor_id_initial_count

    # Typically, there is only one fake tensor mode and you test for it by
    # doing an isinstance test. However, in some situations, there might be
    # TWO fake tensor modes. The canonical example of this is exporting
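# --- Editor's illustration (not part of the diff) ---
# A toy model of the counter bookkeeping described in the FakeTensorMode hunks
# above: the mode snapshots the eager counter at construction, and retracing
# resets to that snapshot so repeated traces do not drift away from eager
# state. Plain Python stand-in, not the (reverted) FakeTensorMode attributes.
class ToyMode:
    def __init__(self, eager_counter: int):
        self.nt_tensor_id_initial_count = eager_counter
        self.nt_tensor_id_counter = eager_counter

    def allocate_id(self) -> int:
        new_id = self.nt_tensor_id_counter
        self.nt_tensor_id_counter += 1
        return new_id

    def reset_nt_tensor_id_counter(self) -> None:
        self.nt_tensor_id_counter = self.nt_tensor_id_initial_count

mode = ToyMode(eager_counter=5)
first_trace = [mode.allocate_id(), mode.allocate_id()]   # [5, 6]
mode.reset_nt_tensor_id_counter()                        # retrace with the same mode
second_trace = [mode.allocate_id(), mode.allocate_id()]  # [5, 6] again
assert first_trace == second_trace
# --- end illustration ---
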
@@ -1247,8 +1205,6 @@ class FakeTensorMode(TorchDispatchMode):

    # No-op if FakeTensorMode is already in use
    def __enter__(self) -> Self:
        import torch.nested._internal.nested_tensor

        prev_only_lift_cpu_tensors = None
        if self.avoid_device_init:
            # See NOTE: [torch.tensor, lift_fresh, and device movement]

@@ -2144,31 +2100,6 @@ class FakeTensorMode(TorchDispatchMode):

        return tree_map(wrap, r)

    def create_symbolic_nested_int(
        self, *, nt_tensor_id: Optional[int] = None
    ) -> torch.SymInt:
        # See Note: [Creating symbolic nested int]
        # Returned nested int always has coeff=1; multiply the result by coeff if needed
        import torch.nested._internal.nested_tensor

        if nt_tensor_id is None:
            nt_tensor_id = self.nt_tensor_id_counter
            assert self.enter_stack, "should only be called while FakeTensorMode is active"
            self.nt_tensor_id_counter += 1
        hint = torch._C._get_nested_int(nt_tensor_id, 1)

        src = torch._dynamo.source.EphemeralSource("intermediate_offsets_or_lengths")
        assert self.shape_env is not None
        ret = self.shape_env.create_symintnode(
            sym=self.shape_env.create_symbol(
                val=hint,
                source=src,
            ),
            hint=hint,
            source=src,
        )
        return ret

    _cpp_meta_supports_symint = ordered_set(
        aten.empty.memory_format,
        aten.empty_strided.default,

@@ -743,9 +743,3 @@ class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(tensor)


def mb_unwrap_functional_tensor(tensor: torch.Tensor):
    if isinstance(tensor, FunctionalTensor):
        return torch._from_functional_tensor(tensor.elem)
    return tensor

@@ -305,8 +305,6 @@ class MetaTensorDescriber:
        }
        type_v = type(t)

        from torch.nested._internal.nested_tensor import _tensor_symint_registry

        # TODO: Is it important to enable torch.inference_mode before querying
        # these values?
        r = MetaTensorDesc(

@@ -335,11 +333,6 @@ class MetaTensorDescriber:
            is_parameter=isinstance(t, torch.nn.Parameter),
            is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v,
            is_nested=is_nested,
            nested_int=(
                _tensor_symint_registry[t].node.nested_int()
                if t in _tensor_symint_registry
                else None
            ),
            is_functional=is_functional,
            layout=layout,
            device=t.device,

@@ -472,10 +465,6 @@ class MetaTensorDesc:
    is_gradtrackingtensor: bool = False
    is_view: bool = False
    is_nested: bool = False
    # We eagerly symbolicize the associated nested int for e.g. offsets / lengths
    # metadata if that offsets is already associated with a nested int.
    # See test_construct_from_jagged_with_input_offsets_mixed_case.
    nested_int: Optional[int] = None
    is_traceable_wrapper_subclass: bool = False
    is_functional: bool = False
    is_conj: bool = False

@@ -515,7 +504,6 @@ class MetaTensorDesc:
        "functorch_stack",
        "autograd_meta_from",
        "data",
        "nested_int",
    ]

    ctx: Optional[object] = None  # is_traceable_wrapper_subclass

@@ -1575,12 +1563,6 @@ class MetaConverter:
        if t.is_parameter:
            r._is_param = True

        # See Note: [Creating symbolic nested int]
        if t.nested_int is not None:
            r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
                nt_tensor_id=t.nested_int
            )

        self.set_tensor_memo(t, r)

    return self.get_tensor_memo(t)

@@ -30,7 +30,6 @@ class FakeTensorProp(torch.fx.Interpreter):
            mode = FakeTensorMode()
        self._mode = mode
        mode.epoch += 1
        mode.reset_nt_tensor_id_counter()

    def run_node(self, n: Node):
        from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings

@@ -13,16 +13,7 @@ _tensor_symint_registry = WeakTensorKeyDictionary()


def get_tensor_symint(tensor, *, coeff=1):
    from torch._subclasses.fake_tensor import FakeTensor
    from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor

    # NB: Only FakeTensor is associated with a memo
    tensor = mb_unwrap_functional_tensor(tensor)
    if isinstance(tensor, FakeTensor):
        return tensor.get_nested_int(coeff=coeff)

    global _tensor_id_counter

    tensor_symint = _tensor_symint_registry.get(tensor)
    if tensor_symint is None:
        tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
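# --- Editor's illustration (not part of the diff) ---
# A sketch of the eager caching performed by get_tensor_symint above, as of the
# code state shown in this diff: constructing two NJTs from the same offsets
# tensor reuses the cached nested int, so their ragged sizes compare equal,
# while a cloned offsets tensor gets a fresh nested int. Public APIs only.
import torch

values = torch.randn(10, 5)
offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64)

nt1 = torch.nested.nested_tensor_from_jagged(values, offsets)
nt2 = torch.nested.nested_tensor_from_jagged(torch.ones_like(values), offsets)
nt3 = torch.nested.nested_tensor_from_jagged(values, offsets.clone())

assert nt1.shape[1] == nt2.shape[1]  # same offsets tensor -> same nested int
assert nt1.shape[1] != nt3.shape[1]  # distinct offsets tensor -> fresh nested int
# --- end illustration ---
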
@@ -251,7 +242,7 @@ class NestedTensor(torch.Tensor):

    @staticmethod
    def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
        from torch._subclasses.fake_tensor import FakeTensor
        from torch.fx.experimental.symbolic_shapes import has_free_symbols

        # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen]
        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5

@@ -266,14 +257,18 @@ class NestedTensor(torch.Tensor):
            metadata_cache["min_seqlen"] = min_seqlen_tensor
        if max_seqlen_tensor is not None:
            metadata_cache["max_seqlen"] = max_seqlen_tensor

        ragged_idx = meta["ragged_idx"]

        # Alternatively, we could make it the caller's responsibility to
        # cache it. But this heuristic seems simple enough.
        # Note that we cannot simply check if is_fake(values) because
        # during aot autograd, FunctionalTensors are not fake but hold
        # symbolic sizes.
        ragged_source = offsets if lengths is None else lengths
        if isinstance(ragged_source, FakeTensor):
        if has_free_symbols(ragged_source) or has_free_symbols(values):
            # Associate offsets or lengths (possibly fake, possibly functionalized)
            # with the ragged_size.
            ragged_size = outer_size[ragged_idx]
            ragged_source.nested_int_memo = ragged_size
            _tensor_symint_registry[ragged_source] = ragged_size

        return NestedTensor(
            values,
@@ -511,20 +511,7 @@ def _to_copy_default(func, *args, **kwargs):

    # Copy to a new Python subclass NestedTensor
    new_offsets = inp._offsets.to(device=new_values.device)

    from torch._subclasses.fake_tensor import FakeTensor
    from torch._subclasses.functional_tensor import (
        FunctionalTensor,
        mb_unwrap_functional_tensor,
    )

    if isinstance(new_offsets, (FakeTensor, FunctionalTensor)):
        # Temporary hack until we have the union find
        tgt = mb_unwrap_functional_tensor(new_offsets)
        src = mb_unwrap_functional_tensor(inp._offsets)
        tgt.nested_int_memo = src.nested_int_memo
    else:
        _tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets]
    _tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets]
    inp_kwargs = extract_kwargs(inp)
    inp_kwargs["offsets"] = new_offsets

@@ -5189,33 +5189,6 @@ def make_lazy_class(cls):

    return cls


# Base TestCase for NT tests; used to define common helpers, etc.
class NestedTensorTestCase(TestCase):
    def assertEqualIgnoringNestedInts(self, a, b):
        # unbinding NJTs allows us to compare them as essentially equal without
        # caring about exact nested int comparison
        def _unbind_njts(x):
            if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
                return x.unbind()
            else:
                return x

        self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b))

    @contextlib.contextmanager
    def branch_nested_state(self):
        """Context manager to branch and restore the nested tensor state."""
        nested_tensor_module = torch.nested._internal.nested_tensor
        original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy()
        original_tensor_id_counter = nested_tensor_module._tensor_id_counter
        try:
            yield
        finally:
            nested_tensor_module._tensor_id_counter = original_tensor_id_counter
            nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry

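# --- Editor's illustration (not part of the diff) ---
# A usage sketch of what branch_nested_state above protects: inside the context,
# NJT construction may grow the internal id counter / registry, and on exit both
# are restored. The internal module attributes are implementation details; this
# is illustrative, not a stable API.
import contextlib
import torch
import torch.nested._internal.nested_tensor as nt_module

@contextlib.contextmanager
def branch_nested_state():
    saved_registry = nt_module._tensor_symint_registry.copy()
    saved_counter = nt_module._tensor_id_counter
    try:
        yield
    finally:
        nt_module._tensor_id_counter = saved_counter
        nt_module._tensor_symint_registry = saved_registry

before = nt_module._tensor_id_counter
with branch_nested_state():
    torch.nested.nested_tensor_from_jagged(
        torch.randn(10, 5), torch.tensor([0, 2, 6, 10], dtype=torch.int64)
    )
assert nt_module._tensor_id_counter == before  # state restored on exit
# --- end illustration ---
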
@make_lazy_class
class LazyVal:
    pass