[AOTI][debug logger] small fix for intermediate value debugger for jit when arg is not tensor (#149007)

repro:
```
import torch
import torch._inductor.config as config

config.aot_inductor.debug_intermediate_value_printer = "2"
config.aot_inductor.filtered_kernel_names = "triton_poi_fused__to_copy_add_0"

class Model(torch.nn.Module):
    def forward(self, x):
        x = x.to(torch.float)
        return x + 1

model = Model().cuda()
x = torch.randn(10).cuda().to(torch.float8_e4m3fn)
_ = torch.compile(model, fullgraph=True)(x)

print("done")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149007
Approved by: https://github.com/jingsh
This commit is contained in:
henrylhtsang
2025-03-11 16:04:25 -07:00
committed by PyTorch MergeBot
parent c96ed7e6f5
commit fe01af2242

View File

@ -23,6 +23,9 @@ def _print_debugging_tensor_value_info(msg, arg):
# at jit inductor level codegen
max_numel_to_print = 64
print(msg)
if not isinstance(arg, torch.Tensor):
print("Value: ", arg)
return
numel = arg.float().numel()
# print the debug printing stats
if numel <= max_numel_to_print: