mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 16:44:54 +08:00
[inductor][ez] properly print Pointwise (#165369)
Previously, when we printed a ComputedBuffer for a reduction, we got something like:
```
ComputedBuffer(name='buf0', layout=FixedLayout('cuda:0', torch.float32, size=[1, 768], stride=[768, 1]), data=Reduction(
'cuda',
torch.float32,
def inner_fn(index, rindex):
_, i1 = index
r0_0 = rindex
tmp0 = ops.load(tangents_1, i1 + 768 * r0_0)
tmp1 = ops.to_dtype(tmp0, torch.float32, src_dtype=torch.bfloat16)
tmp2 = ops.load(primals_1, i1 + 768 * r0_0)
tmp3 = ops.to_dtype(tmp2, torch.float32, src_dtype=torch.bfloat16)
tmp4 = ops.load(rsqrt, r0_0)
tmp5 = tmp3 * tmp4
tmp6 = tmp1 * tmp5
return tmp6
,
```
But when we printed a ComputedBuffer for a pointwise op, we got something like:
```
ComputedBuffer(name='buf2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32768, 768], stride=[768, 1]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.bfloat16, inner_fn=<function make_pointwise.<locals>.inner.<locals>.inner_fn at 0x7f12922c5bc0>, ranges=[32768, 768]))
```
Note that the inner-function string is not printed. With this change, the inner_fn string is printed in this case as well:
```
ComputedBuffer(name='buf2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32768, 768], stride=[768, 1]), data=Pointwise(
'cuda',
torch.bfloat16,
def inner_fn(index):
i0, i1 = index
tmp0 = ops.load(tangents_1, i1 + 768 * i0)
tmp1 = ops.to_dtype(tmp0, torch.float32, src_dtype=torch.bfloat16)
tmp2 = ops.load(primals_2, i1)
tmp3 = tmp1 * tmp2
tmp4 = ops.load(rsqrt, i0)
tmp5 = tmp3 * tmp4
tmp6 = ops.load(buf1, i0)
tmp7 = ops.constant(-0.5, torch.float32)
tmp8 = tmp6 * tmp7
tmp9 = ops.load(rsqrt, i0)
tmp10 = tmp9 * tmp9
tmp11 = tmp10 * tmp9
tmp12 = tmp8 * tmp11
tmp13 = ops.constant(0.0013020833333333333, torch.float32)
tmp14 = tmp12 * tmp13
tmp15 = ops.load(primals_1, i1 + 768 * i0)
tmp16 = ops.to_dtype(tmp15, torch.float32, src_dtype=torch.bfloat16)
tmp17 = tmp14 * tmp16
tmp18 = tmp5 + tmp17
tmp19 = ops.load(buf1, i0)
tmp20 = ops.constant(-0.5, torch.float32)
tmp21 = tmp19 * tmp20
tmp22 = ops.load(rsqrt, i0)
tmp23 = tmp22 * tmp22
tmp24 = tmp23 * tmp22
tmp25 = tmp21 * tmp24
tmp26 = ops.constant(0.0013020833333333333, torch.float32)
tmp27 = tmp25 * tmp26
tmp28 = ops.load(primals_1, i1 + 768 * i0)
tmp29 = ops.to_dtype(tmp28, torch.float32, src_dtype=torch.bfloat16)
tmp30 = tmp27 * tmp29
tmp31 = tmp18 + tmp30
tmp32 = ops.to_dtype(tmp31, torch.bfloat16, src_dtype=torch.float32)
return tmp32
,
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165369
Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
5fbf93b774
commit
18b3658df9
@@ -1074,6 +1074,11 @@ class Pointwise(Loops):
         return self.inner_fn

+    def __str__(self) -> str:
+        return self._to_str(("ranges",))
+
+    __repr__ = __str__
+
     def get_reduction_size(self) -> Sequence[sympy.Expr]:
         return []
||||
Reference in New Issue
Block a user