mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 21:47:15 +08:00
[AOTI][debug logger] small fix for intermediate value debugger for jit when arg is not tensor (#149007)
repro:
```
import torch
import torch._inductor.config as config
config.aot_inductor.debug_intermediate_value_printer = "2"
config.aot_inductor.filtered_kernel_names = "triton_poi_fused__to_copy_add_0"
class Model(torch.nn.Module):
def forward(self, x):
x = x.to(torch.float)
return x + 1
model = Model().cuda()
x = torch.randn(10).cuda().to(torch.float8_e4m3fn)
_ = torch.compile(model, fullgraph=True)(x)
print("done")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149007
Approved by: https://github.com/jingsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
c96ed7e6f5
commit
fe01af2242
@ -23,6 +23,9 @@ def _print_debugging_tensor_value_info(msg, arg):
|
||||
# at jit inductor level codegen
|
||||
max_numel_to_print = 64
|
||||
print(msg)
|
||||
if not isinstance(arg, torch.Tensor):
|
||||
print("Value: ", arg)
|
||||
return
|
||||
numel = arg.float().numel()
|
||||
# print the debug printing stats
|
||||
if numel <= max_numel_to_print:
|
||||
|
||||
Reference in New Issue
Block a user