mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix aten.div type promotion for FakeTensor (#150874)
Summary: When we divide a FakeTensor by an integer using the fast op implementation, the type promotion should be `ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT` so we get a float when dividing an int FakeTensor by an integer. ``` FAST = get_fast_op_impls() fast_div = FAST[torch.ops.aten.div.Tensor] fast_div(fake_tensor, some_int) ``` Test Plan: ``` python test/test_fake_tensor.py -k test_fast_div ``` Differential Revision: D72667430 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150874 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
d3a2872c67
commit
cfab04d01b
@ -972,6 +972,14 @@ class FakeTensorTest(TestCase):
|
||||
self.assertIsInstance(r[0], FakeTensor)
|
||||
self.assertIsInstance(r[1], FakeTensor)
|
||||
|
||||
def test_fast_div(self):
|
||||
mode = FakeTensorMode()
|
||||
with mode:
|
||||
x = torch.empty(2, 2, device="cpu", dtype=torch.int32)
|
||||
from torch._subclasses.fake_impls import get_fast_op_impls
|
||||
fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor]
|
||||
y = fast_div(mode, x, 2)
|
||||
self.assertEqual(y.dtype, torch.float32)
|
||||
|
||||
instantiate_parametrized_tests(FakeTensorTest)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user