[NJT] Inline through torch.nested.nested_tensor_from_jagged instead of graph break (#124343)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124343
Approved by: https://github.com/jbschlosser
This commit is contained in:
soulitzer
2024-04-19 14:15:28 -04:00
committed by PyTorch MergeBot
parent acbf888a13
commit cf5ca58e7f
3 changed files with 10 additions and 0 deletions

View File

@ -1361,6 +1361,14 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
self._check_recompiles(fn, (nt,), (nt2,), False)
self._check_recompiles(fn, (nt,), (nt3,), True)
def test_inline_nested_tensor_from_jagged(self):
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
def fn(x):
return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())
torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)
def _get_views(self):
# Test all cases with both an NT base and a dense base
# Subclass -> Subclass

View File

@ -634,6 +634,7 @@ class TestExecutionTrace(TestCase):
found_root_node = True
assert found_root_node
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500")
def test_execution_trace_nested_tensor(self):
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()

View File

@ -173,6 +173,7 @@ manual_torch_name_rule_map = {
"torch.nn.Parameter": TorchInGraphFunctionVariable,
"torch._nested_tensor_from_mask": SkipFunctionVariable,
"torch._nested_from_padded": SkipFunctionVariable,
"torch.nested.nested_tensor_from_jagged": UserFunctionVariable,
# symbol operators implemented in Python
"torch.sym_not": TorchInGraphFunctionVariable,
"torch.sym_float": TorchInGraphFunctionVariable,