add tensor subclass printing support in fx/graph.py (#164403)
It was previously quite misleading, since it looked like the inputs to the dynamo graph were plain tensors when in reality they are tensor subclasses.

Before:
```
class GraphModule(torch.nn.Module):
    def forward(self, L_input_batch_inputs_: "i64[2, 512][512, 1]cuda:0", L_self_parameters_weight_: "f32[202048, 256][256, 1]cuda:0"):
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, L_input_batch_inputs_: "DTensor(i64[2, 512][512, 1]cuda:0)", L_self_parameters_weight_: "DTensor(f32[202048, 256][256, 1]cuda:0)"):
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164403
Approved by: https://github.com/ezyang
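To see the new annotations end to end, here is a minimal sketch (not part of this PR; the backend name `inspect_backend` and the example function are made up for illustration) that compiles a small function over a jagged NestedTensor and prints the captured Dynamo graph:

```python
# Minimal repro sketch: with this change, the printed placeholder annotations
# read e.g. "NestedTensor(f64[3, s71, 5])" instead of the bare "f64[3, s71, 5]".
import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.print_readable()  # readable output goes through fx.graph.CodeGen's annotations
    return gm.forward    # run the captured graph unchanged

@torch.compile(backend=inspect_backend, dynamic=True)
def f(nt):
    return nt + 2

nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(4, 5), torch.randn(3, 5)],
    layout=torch.jagged,
    dtype=torch.float64,
)
f(nt)
```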
committed by PyTorch MergeBot
parent c45d56dd00
commit 8c54101933
The expected Dynamo graph in the NestedTensor test changes accordingly:

```diff
@@ -3403,10 +3403,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase):
             norm_graph,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"):
+    def forward(self, s71: "Sym(s71)", L_nt_: "NestedTensor(f64[3, s71, 5])"):
         l_nt_ = L_nt_
 
-        add: "f64[3, s71, 5]" = l_nt_ + 2; l_nt_ = None
+        add: "NestedTensor(f64[3, s71, 5])" = l_nt_ + 2; l_nt_ = None
         return (add,)
 """,  # noqa: B950
         )
```
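For reference, the `norm_graph` string compared above is typically captured with Dynamo's testing helpers; the exact setup of this test is not shown in the hunk, so the following is an assumed sketch of that pattern:

```python
# Assumed capture pattern (based on common torch._dynamo.testing usage):
# record the compiled graph, then normalize its readable form for comparison.
import torch
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm

backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, dynamic=True)
def fn(nt):
    return nt + 2

nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(4, 5), torch.randn(3, 5)],
    layout=torch.jagged,
    dtype=torch.float64,
)
fn(nt)

# print_readable(print_output=False) returns the source string instead of printing it;
# placeholder types in it now read "NestedTensor(f64[...])".
norm_graph = normalize_gm(backend.graphs[0].print_readable(print_output=False))
```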
And the change in fx/graph.py's `CodeGen`:

```diff
@@ -648,24 +648,37 @@ class CodeGen:
                 "val",
                 node.meta.get("tensor_meta", node.meta.get("example_value", None)),
             )
 
+            def _tensor_annotation(t: torch.Tensor) -> str:
+                stride = stringify_shape(t.stride()) if include_stride else ""
+                device = f"{t.device}" if include_device else ""
+                return (
+                    f"{red(dtype_abbrs[t.dtype])}"
+                    f"{blue(stringify_shape(t.shape))}"
+                    f"{dim_blue(stride)}"
+                    f"{dim_green(device)}"
+                )
+
             # use string as annotation, to make it valid python code
             if isinstance(meta_val, torch.Tensor) and meta_val.layout not in (
                 torch.sparse_csc,
                 torch.sparse_csr,
             ):
-                stride_annotation = (
-                    f"{stringify_shape(meta_val.stride())}"
-                    if include_stride
-                    else ""
-                )
-                device_annotation = f"{meta_val.device}" if include_device else ""
-                maybe_type_annotation = (
-                    f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}'
-                    f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"'
-                )
+                # Fake tensors cause tests to wobble, so do not custom print them.
+                is_plain = type(meta_val) is torch.Tensor or isinstance(
+                    meta_val, torch._subclasses.FakeTensor
+                )
+                core = _tensor_annotation(meta_val)
+                if is_plain:
+                    maybe_type_annotation = f': "{core}"'
+                else:
+                    cls = meta_val.__class__.__name__
+                    maybe_type_annotation = f': "{cls}({core})"'
+
             elif isinstance(meta_val, py_sym_types):
                 val_str = CodeGen._sym_repr(meta_val)
                 maybe_type_annotation = f': "Sym({val_str})"'
+
             elif isinstance(meta_val, TensorMetadata):
                 maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'
+
```
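To isolate the rule the new branch implements, here is a standalone sketch (the simplified names `DTYPE_ABBRS`, `stringify_shape`, and `annotate` are mine, not the PyTorch source): plain tensors and FakeTensors keep the bare dtype/shape/stride/device form, while any other tensor subclass is wrapped in its class name.

```python
# Standalone sketch of the annotation rule added above (simplified: no color
# helpers, only a few dtype abbreviations). Plain tensors and FakeTensors keep
# the bare form; other subclasses are wrapped as "ClassName(...)".
import torch
from torch._subclasses import FakeTensor

DTYPE_ABBRS = {torch.float32: "f32", torch.float64: "f64", torch.int64: "i64"}

def stringify_shape(shape) -> str:
    return f"[{', '.join(str(d) for d in shape)}]"

def annotate(t: torch.Tensor, include_stride: bool = True, include_device: bool = True) -> str:
    core = DTYPE_ABBRS.get(t.dtype, str(t.dtype)) + stringify_shape(t.shape)
    if include_stride:
        core += stringify_shape(t.stride())
    if include_device:
        core += str(t.device)
    is_plain = type(t) is torch.Tensor or isinstance(t, FakeTensor)
    return f'"{core}"' if is_plain else f'"{type(t).__name__}({core})"'

print(annotate(torch.randn(2, 3)))                        # -> "f32[2, 3][3, 1]cpu"
print(annotate(torch.nn.Parameter(torch.randn(2, 3))))    # -> "Parameter(f32[2, 3][3, 1]cpu)"
```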