mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Construct NJT without graph breaks" (#133145)
This reverts commit 911154271309667b55dfb963ec6384bd0048019b. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133145 Approved by: https://github.com/YuqingJ
This commit is contained in:
committed by
PyTorch MergeBot
parent
e890d888d9
commit
05de2b2d0f
@ -5189,33 +5189,6 @@ def make_lazy_class(cls):
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
# Base TestCase for NT tests; used to define common helpers, etc.
|
||||
class NestedTensorTestCase(TestCase):
|
||||
def assertEqualIgnoringNestedInts(self, a, b):
|
||||
# unbinding NJTs allows us to compare them as essentially equal without
|
||||
# caring about exact nested int comparison
|
||||
def _unbind_njts(x):
|
||||
if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
|
||||
return x.unbind()
|
||||
else:
|
||||
return x
|
||||
|
||||
self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def branch_nested_state(self):
|
||||
"""Context manager to branch and restore the nested tensor state."""
|
||||
nested_tensor_module = torch.nested._internal.nested_tensor
|
||||
original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy()
|
||||
original_tensor_id_counter = nested_tensor_module._tensor_id_counter
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
nested_tensor_module._tensor_id_counter = original_tensor_id_counter
|
||||
nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry
|
||||
|
||||
|
||||
@make_lazy_class
|
||||
class LazyVal:
|
||||
pass
|
||||
|
Reference in New Issue
Block a user