mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[NJT] Inline through torch.nested.nested_tensor_from_jagged instead of graph break (#124343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124343 Approved by: https://github.com/jbschlosser
This commit is contained in:
committed by
PyTorch MergeBot
parent
acbf888a13
commit
cf5ca58e7f
@ -1361,6 +1361,14 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
|
||||
self._check_recompiles(fn, (nt,), (nt2,), False)
|
||||
self._check_recompiles(fn, (nt,), (nt3,), True)
|
||||
|
||||
def test_inline_nested_tensor_from_jagged(self):
|
||||
nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
|
||||
|
||||
def fn(x):
|
||||
return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())
|
||||
|
||||
torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)
|
||||
|
||||
def _get_views(self):
|
||||
# Test all cases with both an NT base and a dense base
|
||||
# Subclass -> Subclass
|
||||
|
@ -634,6 +634,7 @@ class TestExecutionTrace(TestCase):
|
||||
found_root_node = True
|
||||
assert found_root_node
|
||||
|
||||
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500")
|
||||
def test_execution_trace_nested_tensor(self):
|
||||
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
|
||||
fp.close()
|
||||
|
@ -173,6 +173,7 @@ manual_torch_name_rule_map = {
|
||||
"torch.nn.Parameter": TorchInGraphFunctionVariable,
|
||||
"torch._nested_tensor_from_mask": SkipFunctionVariable,
|
||||
"torch._nested_from_padded": SkipFunctionVariable,
|
||||
"torch.nested.nested_tensor_from_jagged": UserFunctionVariable,
|
||||
# symbol operators implemented in Python
|
||||
"torch.sym_not": TorchInGraphFunctionVariable,
|
||||
"torch.sym_float": TorchInGraphFunctionVariable,
|
||||
|
Reference in New Issue
Block a user