diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 1b99bd94061e..7dad38355e20 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -972,6 +972,14 @@ class FakeTensorTest(TestCase):
         self.assertIsInstance(r[0], FakeTensor)
         self.assertIsInstance(r[1], FakeTensor)
 
+    def test_fast_div(self):
+        mode = FakeTensorMode()
+        with mode:
+            x = torch.empty(2, 2, device="cpu", dtype=torch.int32)
+            from torch._subclasses.fake_impls import get_fast_op_impls
+            fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor]
+            y = fast_div(mode, x, 2)
+            self.assertEqual(y.dtype, torch.float32)
 
 instantiate_parametrized_tests(FakeTensorTest)
 
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index bc7bc1ba7f82..9d85bf4c77b3 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -890,7 +890,9 @@ def infer_size(a, b):
     return tuple(expandedSizes)
 
 
-def make_fast_binary_impl(slow_ref):
+def make_fast_binary_impl(
+    slow_ref, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+):
     def fast_binary_impl(mode, *args, **kwargs):
         def slow(msg):
             count_label(f"slow {msg}")
@@ -957,7 +959,7 @@ def make_fast_binary_impl(slow_ref):
         # compute promotion
         # TODO: we don't need the compute type
         _, common_dtype = elementwise_dtypes(
-            *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+            *operands, type_promotion_kind=type_promotion_kind
         )
 
         # check all tensors on same device
@@ -1042,7 +1044,10 @@ def get_fast_op_impls():
     )
     register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
     register_fast_op_impl(torch.ops.aten.div.Tensor)(
-        make_fast_binary_impl(torch._refs.div)
+        make_fast_binary_impl(
+            torch._refs.div,
+            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        )
     )
     register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach)
     return FAST_OP_IMPLEMENTATIONS
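
Note (not part of the patch): a minimal sketch of the eager-mode behavior this change aligns the fake-tensor fast path with. aten.div.Tensor uses INT_TO_FLOAT type promotion, so dividing integer tensors yields a floating-point result, whereas ops like mul keep the DEFAULT promotion and stay integral:

    # Illustrative sketch only, not part of the patch: the eager-mode dtype
    # behavior the fake-tensor fast path must reproduce after this change.
    import torch

    x = torch.empty(2, 2, dtype=torch.int32)

    # div promotes integer inputs to float (INT_TO_FLOAT promotion).
    print(torch.div(x, 2).dtype)  # torch.float32

    # mul keeps the common integer dtype (DEFAULT promotion).
    print(torch.mul(x, 2).dtype)  # torch.int32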