Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[inductor][ez] properly print Pointwise (#165369)
Previously, when we print a ComputedBuffer for a reduction, we get something like:

```
ComputedBuffer(name='buf0', layout=FixedLayout('cuda:0', torch.float32, size=[1, 768], stride=[768, 1]), data=Reduction(
  'cuda',
  torch.float32,
  def inner_fn(index, rindex):
      _, i1 = index
      r0_0 = rindex
      tmp0 = ops.load(tangents_1, i1 + 768 * r0_0)
      tmp1 = ops.to_dtype(tmp0, torch.float32, src_dtype=torch.bfloat16)
      tmp2 = ops.load(primals_1, i1 + 768 * r0_0)
      tmp3 = ops.to_dtype(tmp2, torch.float32, src_dtype=torch.bfloat16)
      tmp4 = ops.load(rsqrt, r0_0)
      tmp5 = tmp3 * tmp4
      tmp6 = tmp1 * tmp5
      return tmp6
  ,
```

But if we print a ComputedBuffer for a pointwise op, we get something like:

```
ComputedBuffer(name='buf2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32768, 768], stride=[768, 1]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.bfloat16, inner_fn=<function make_pointwise.<locals>.inner.<locals>.inner_fn at 0x7f12922c5bc0>, ranges=[32768, 768]))
```

Note that the inner_fn string is not printed.

With the change, the inner_fn string gets printed in this case as well:

```
ComputedBuffer(name='buf2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[32768, 768], stride=[768, 1]), data=Pointwise(
  'cuda',
  torch.bfloat16,
  def inner_fn(index):
      i0, i1 = index
      tmp0 = ops.load(tangents_1, i1 + 768 * i0)
      tmp1 = ops.to_dtype(tmp0, torch.float32, src_dtype=torch.bfloat16)
      tmp2 = ops.load(primals_2, i1)
      tmp3 = tmp1 * tmp2
      tmp4 = ops.load(rsqrt, i0)
      tmp5 = tmp3 * tmp4
      tmp6 = ops.load(buf1, i0)
      tmp7 = ops.constant(-0.5, torch.float32)
      tmp8 = tmp6 * tmp7
      tmp9 = ops.load(rsqrt, i0)
      tmp10 = tmp9 * tmp9
      tmp11 = tmp10 * tmp9
      tmp12 = tmp8 * tmp11
      tmp13 = ops.constant(0.0013020833333333333, torch.float32)
      tmp14 = tmp12 * tmp13
      tmp15 = ops.load(primals_1, i1 + 768 * i0)
      tmp16 = ops.to_dtype(tmp15, torch.float32, src_dtype=torch.bfloat16)
      tmp17 = tmp14 * tmp16
      tmp18 = tmp5 + tmp17
      tmp19 = ops.load(buf1, i0)
      tmp20 = ops.constant(-0.5, torch.float32)
      tmp21 = tmp19 * tmp20
      tmp22 = ops.load(rsqrt, i0)
      tmp23 = tmp22 * tmp22
      tmp24 = tmp23 * tmp22
      tmp25 = tmp21 * tmp24
      tmp26 = ops.constant(0.0013020833333333333, torch.float32)
      tmp27 = tmp25 * tmp26
      tmp28 = ops.load(primals_1, i1 + 768 * i0)
      tmp29 = ops.to_dtype(tmp28, torch.float32, src_dtype=torch.bfloat16)
      tmp30 = tmp27 * tmp29
      tmp31 = tmp18 + tmp30
      tmp32 = ops.to_dtype(tmp31, torch.bfloat16, src_dtype=torch.float32)
      return tmp32
  ,
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165369
Approved by: https://github.com/eellison
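As an illustrative sketch only (not the PyTorch implementation — `ToyPointwise` and its `__str__` are hypothetical stand-ins): a container that stores a function prints just `<function ...>` with the default repr, and a custom `__str__` that expands the function's source makes the debugging output readable, which is the spirit of this change to `Pointwise`:

```python
# Illustrative sketch, not PyTorch code: make a stored function's body
# visible when printing the container, instead of "<function ...>".
import inspect
import textwrap
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyPointwise:
    name: str
    inner_fn: Callable

    def __str__(self) -> str:
        try:
            # Expand the stored function into its source text.
            body = textwrap.dedent(inspect.getsource(self.inner_fn))
        except OSError:
            # Fall back to the opaque default when source is unavailable.
            body = repr(self.inner_fn) + "\n"
        return f"ToyPointwise(\n  {self.name!r},\n{textwrap.indent(body, '  ')})"

    # Without this alias, repr() would still show the opaque dataclass default.
    __repr__ = __str__


def inner_fn(index):
    i0, i1 = index
    return i0 + 768 * i1


buf = ToyPointwise("buf2", inner_fn)
print(buf)  # prints the name and, when available, the body of inner_fn
```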
Committed by: PyTorch MergeBot
Parent: 5fbf93b774
Commit: 18b3658df9
```diff
@@ -1074,6 +1074,11 @@ class Pointwise(Loops):
         return self.inner_fn
 
+    def __str__(self) -> str:
+        return self._to_str(("ranges",))
+
+    __repr__ = __str__
+
     def get_reduction_size(self) -> Sequence[sympy.Expr]:
         return []
```
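A minimal sketch (not PyTorch code — the class names are made up) of why the patch also aliases `__repr__ = __str__`: defining `__str__` alone changes `print()`/`str()` but not `repr()`, and `repr()` is what interactive sessions and debuggers typically display.

```python
# Defining only __str__ leaves repr() at the opaque default.
class OnlyStr:
    def __str__(self) -> str:
        return "readable form"


# Aliasing __repr__ = __str__ makes repr() readable too.
class Both:
    def __str__(self) -> str:
        return "readable form"

    __repr__ = __str__


print(str(OnlyStr()))   # readable form
print(repr(OnlyStr()))  # default: <__main__.OnlyStr object at 0x...>
print(repr(Both()))     # readable form
```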