add tensor subclass printing support in fx/graph.py (#164403)

it was previously quite misleading since it looks like the inputs to the
dynamo graph are plain tensors when in reality they are tensor subclasses

before
```
class GraphModule(torch.nn.Module):
    def forward(self, L_input_batch_inputs_: "i64[2, 512][512, 1]cuda:0", L_self_parameters_weight_: "f32[202048, 256][256, 1]cuda:0"):
```

after
```
    class GraphModule(torch.nn.Module):
        def forward(self, L_input_batch_inputs_: "DTensor(i64[2, 512][512, 1]cuda:0)", L_self_parameters_weight_: "DTensor(f32[202048, 256][256, 1]cuda:0)"):
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164403
Approved by: https://github.com/ezyang
This commit is contained in:
bobrenjc93
2025-10-02 10:16:35 -07:00
committed by PyTorch MergeBot
parent c45d56dd00
commit 8c54101933
2 changed files with 24 additions and 11 deletions

View File

@ -3403,10 +3403,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
norm_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"):
def forward(self, s71: "Sym(s71)", L_nt_: "NestedTensor(f64[3, s71, 5])"):
l_nt_ = L_nt_
add: "f64[3, s71, 5]" = l_nt_ + 2; l_nt_ = None
add: "NestedTensor(f64[3, s71, 5])" = l_nt_ + 2; l_nt_ = None
return (add,)
""", # noqa: B950
)

View File

@ -648,24 +648,37 @@ class CodeGen:
"val",
node.meta.get("tensor_meta", node.meta.get("example_value", None)),
)
def _tensor_annotation(t: torch.Tensor) -> str:
stride = stringify_shape(t.stride()) if include_stride else ""
device = f"{t.device}" if include_device else ""
return (
f"{red(dtype_abbrs[t.dtype])}"
f"{blue(stringify_shape(t.shape))}"
f"{dim_blue(stride)}"
f"{dim_green(device)}"
)
# use string as annotation, to make it valid python code
if isinstance(meta_val, torch.Tensor) and meta_val.layout not in (
torch.sparse_csc,
torch.sparse_csr,
):
stride_annotation = (
f"{stringify_shape(meta_val.stride())}"
if include_stride
else ""
)
device_annotation = f"{meta_val.device}" if include_device else ""
maybe_type_annotation = (
f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}'
f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"'
# Fake tensors cause tests to wobble, so do not custom print them.
is_plain = type(meta_val) is torch.Tensor or isinstance(
meta_val, torch._subclasses.FakeTensor
)
core = _tensor_annotation(meta_val)
if is_plain:
maybe_type_annotation = f': "{core}"'
else:
cls = meta_val.__class__.__name__
maybe_type_annotation = f': "{cls}({core})"'
elif isinstance(meta_val, py_sym_types):
val_str = CodeGen._sym_repr(meta_val)
maybe_type_annotation = f': "Sym({val_str})"'
elif isinstance(meta_val, TensorMetadata):
maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'