Reland "Construct NJT without graph breaks" (#133196)

This reverts commit 154d40ca488e6979ce9c2de89d8a35b53129ebea.

and adds changes from https://github.com/pytorch/pytorch/pull/133061

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133196
Approved by: https://github.com/ezyang
ghstack dependencies: #133145
This commit is contained in:
soulitzer
2024-08-13 17:10:03 -04:00
committed by PyTorch MergeBot
parent f23dbefe52
commit 4af4910b1a
10 changed files with 587 additions and 41 deletions

View File

@ -5189,6 +5189,33 @@ def make_lazy_class(cls):
return cls
# Base TestCase for NT tests; used to define common helpers, etc.
class NestedTensorTestCase(TestCase):
def assertEqualIgnoringNestedInts(self, a, b):
# unbinding NJTs allows us to compare them as essentially equal without
# caring about exact nested int comparison
def _unbind_njts(x):
if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
return x.unbind()
else:
return x
self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b))
@contextlib.contextmanager
def branch_nested_state(self):
"""Context manager to branch and restore the nested tensor state."""
nested_tensor_module = torch.nested._internal.nested_tensor
original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy()
original_tensor_id_counter = nested_tensor_module._tensor_id_counter
try:
yield
finally:
nested_tensor_module._tensor_id_counter = original_tensor_id_counter
nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry
@make_lazy_class
class LazyVal:
pass