Fix aten.div type promotion for FakeTensor (#150874)

Summary:
When we divide a FakeTensor by an integer using the fast op implementation, the type promotion should be `ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT` so we get a float when dividing an int FakeTensor by an integer.

```
FAST = get_fast_op_impls()
fast_div = FAST[torch.ops.aten.div.Tensor]
fast_div(fake_tensor, some_int)
```

Test Plan:
```
python test/test_fake_tensor.py -k test_fast_div
```

Differential Revision: D72667430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150874
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu
2025-04-09 18:52:01 +00:00
committed by PyTorch MergeBot
parent d3a2872c67
commit cfab04d01b
2 changed files with 16 additions and 3 deletions

View File

@ -972,6 +972,14 @@ class FakeTensorTest(TestCase):
self.assertIsInstance(r[0], FakeTensor)
self.assertIsInstance(r[1], FakeTensor)
def test_fast_div(self):
mode = FakeTensorMode()
with mode:
x = torch.empty(2, 2, device="cpu", dtype=torch.int32)
from torch._subclasses.fake_impls import get_fast_op_impls
fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor]
y = fast_div(mode, x, 2)
self.assertEqual(y.dtype, torch.float32)
instantiate_parametrized_tests(FakeTensorTest)

View File

@ -890,7 +890,9 @@ def infer_size(a, b):
return tuple(expandedSizes)
def make_fast_binary_impl(slow_ref):
def make_fast_binary_impl(
slow_ref, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
):
def fast_binary_impl(mode, *args, **kwargs):
def slow(msg):
count_label(f"slow {msg}")
@ -957,7 +959,7 @@ def make_fast_binary_impl(slow_ref):
# compute promotion
# TODO: we don't need the compute type
_, common_dtype = elementwise_dtypes(
*operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
*operands, type_promotion_kind=type_promotion_kind
)
# check all tensors on same device
@ -1042,7 +1044,10 @@ def get_fast_op_impls():
)
register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type]
register_fast_op_impl(torch.ops.aten.div.Tensor)(
make_fast_binary_impl(torch._refs.div)
make_fast_binary_impl(
torch._refs.div,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
)
register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach)
return FAST_OP_IMPLEMENTATIONS