Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Fix truediv numerics between eager and compile (#164144)"
This reverts commit 68913d8f2a953bdbada4033101b04f6e8d49dabe. Reverted https://github.com/pytorch/pytorch/pull/164144 on behalf of https://github.com/malfet due to It breaks CI again, why was it landed for 3 times in a row without any changes? ([comment](https://github.com/pytorch/pytorch/pull/164144#issuecomment-3390973016))
@@ -2580,31 +2580,6 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
         actual = compiled(*example_inputs)
         self.assertEqual(actual, correct)
 
-    def test_truediv_numerics_with_eager(self):
-        from decimal import Decimal
-
-        y, x = 7.0, 11.0
-
-        @torch.compile
-        def compiled_divide(x, y):
-            return x / y
-
-        for y_dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
-            for x_dtype in [
-                torch.float16,
-                torch.bfloat16,
-                torch.float32,
-                torch.float64,
-            ]:
-                y_ten = torch.tensor([y], dtype=y_dtype, device="cuda")
-                x_ten = torch.tensor([x], dtype=x_dtype, device="cuda")
-
-                torch._dynamo.reset()
-                compiled_div = Decimal(compiled_divide(x, y_ten).item())
-                eager_div = Decimal((x / y_ten).item())
-
-                self.assertEqual(eager_div, compiled_div)
-
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
 
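For context on what is being reverted here: the removed test asserts that eager and torch.compile'd floating-point division agree bit-for-bit, using Decimal for an exact comparison. A minimal standalone sketch of the same check (illustrative only; assumes a CUDA device, and the helper name check_truediv is made up here):

from decimal import Decimal

import torch


def check_truediv(x: float, y: float, y_dtype: torch.dtype) -> bool:
    # Hypothetical helper mirroring the removed test: compare eager and
    # compiled division of a Python float by a one-element CUDA tensor.
    y_ten = torch.tensor([y], dtype=y_dtype, device="cuda")

    @torch.compile
    def compiled_divide(a, b):
        return a / b

    torch._dynamo.reset()  # force a fresh compile for each dtype checked
    compiled_val = Decimal(compiled_divide(x, y_ten).item())
    eager_val = Decimal((x / y_ten).item())
    # Decimal(float) is exact, so equality here means bit-identical results.
    return compiled_val == eager_val


print(check_truediv(11.0, 7.0, torch.float32))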
@@ -1075,15 +1075,7 @@ class TritonOverrides(OpOverrides):
 
     @staticmethod
     def truediv(x, y):
-        x_dtype = getattr(x, "dtype", None)
-        y_dtype = getattr(y, "dtype", None)
-
-        if x_dtype == torch.float32 and y_dtype == torch.float32:
-            # x / y in Triton is lowered to div.full which is approx
-            # we want div_rn to adhere with eager
-            out = f"triton.language.div_rn({x}, {y})"
-        else:
-            out = f"({x} / {y})"
+        out = f"({x} / {y})"
         if low_precision_fp_var(x) or low_precision_fp_var(y):
             out_dtype = get_dtype_handler().truediv(x, y)
             if out_dtype in (torch.float16, torch.float32):
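The codegen hunk above is being reverted because the special case emitted triton.language.div_rn for fp32/fp32 division, on the reasoning (from the removed comment) that plain x / y in Triton is lowered to an approximate divide while eager rounds to nearest. If you want to see which division a given PyTorch build actually generates, one option is the private helper torch._inductor.utils.run_and_get_code, which Inductor's own tests use; this is a sketch under the assumption that it keeps its current (result, list-of-source-strings) return shape and that a CUDA device is available:

import torch
from torch._inductor.utils import run_and_get_code  # private API, may change between releases


def div(a, b):
    return a / b


compiled = torch.compile(div)
a = torch.rand(1024, device="cuda", dtype=torch.float32)
b = torch.rand(1024, device="cuda", dtype=torch.float32) + 0.5

# run_and_get_code runs the compiled function and also returns the generated
# source code, one string per compiled graph.
result, codes = run_and_get_code(compiled, a, b)
print(any("div_rn" in code for code in codes))  # False on builds where this revert is applied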