mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Fix truediv numerics between eager and compile (#164144)"
This reverts commit 68913d8f2a953bdbada4033101b04f6e8d49dabe. Reverted https://github.com/pytorch/pytorch/pull/164144 on behalf of https://github.com/malfet due to It breaks CI again, why was it landed for 3 times in a row without any changes? ([comment](https://github.com/pytorch/pytorch/pull/164144#issuecomment-3390973016))
This commit is contained in:
@ -2580,31 +2580,6 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
|
||||
actual = compiled(*example_inputs)
|
||||
self.assertEqual(actual, correct)
|
||||
|
||||
def test_truediv_numerics_with_eager(self):
|
||||
from decimal import Decimal
|
||||
|
||||
y, x = 7.0, 11.0
|
||||
|
||||
@torch.compile
|
||||
def compiled_divide(x, y):
|
||||
return x / y
|
||||
|
||||
for y_dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
|
||||
for x_dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
torch.float64,
|
||||
]:
|
||||
y_ten = torch.tensor([y], dtype=y_dtype, device="cuda")
|
||||
x_ten = torch.tensor([x], dtype=x_dtype, device="cuda")
|
||||
|
||||
torch._dynamo.reset()
|
||||
compiled_div = Decimal(compiled_divide(x, y_ten).item())
|
||||
eager_div = Decimal((x / y_ten).item())
|
||||
|
||||
self.assertEqual(eager_div, compiled_div)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
@ -1075,15 +1075,7 @@ class TritonOverrides(OpOverrides):
|
||||
|
||||
@staticmethod
|
||||
def truediv(x, y):
|
||||
x_dtype = getattr(x, "dtype", None)
|
||||
y_dtype = getattr(y, "dtype", None)
|
||||
|
||||
if x_dtype == torch.float32 and y_dtype == torch.float32:
|
||||
# x / y in Triton is lowered to div.full which is approx
|
||||
# we want div_rn to adhere with eager
|
||||
out = f"triton.language.div_rn({x}, {y})"
|
||||
else:
|
||||
out = f"({x} / {y})"
|
||||
out = f"({x} / {y})"
|
||||
if low_precision_fp_var(x) or low_precision_fp_var(y):
|
||||
out_dtype = get_dtype_handler().truediv(x, y)
|
||||
if out_dtype in (torch.float16, torch.float32):
|
||||
|
Reference in New Issue
Block a user