Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix aten.div type promotion for FakeTensor (#150874)
Summary: When we divide a FakeTensor by an integer using the fast op implementation, the type promotion kind should be `ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT`, so that dividing an int FakeTensor by an integer yields a float:

```
FAST = get_fast_op_impls()
fast_div = FAST[torch.ops.aten.div.Tensor]
fast_div(fake_tensor, some_int)
```

Test Plan:

```
python test/test_fake_tensor.py -k test_fast_div
```

Differential Revision: D72667430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150874
Approved by: https://github.com/angelayi
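For context (an illustrative sketch, not part of the commit message): the behavior being restored matches eager PyTorch, where true division of an integer tensor by an integer promotes to the default float dtype.

```python
import torch

x = torch.empty(2, 2, dtype=torch.int32)

# Eager mode: aten.div.Tensor is true division, so an integer tensor
# divided by an integer promotes to the default float dtype.
print((x / 2).dtype)  # torch.float32

# Before this fix, the FakeTensor fast path used DEFAULT promotion and
# reported an integer dtype for the same division.
```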
committed by PyTorch MergeBot
parent d3a2872c67
commit cfab04d01b
test/test_fake_tensor.py

```
@@ -972,6 +972,14 @@ class FakeTensorTest(TestCase):
         self.assertIsInstance(r[0], FakeTensor)
         self.assertIsInstance(r[1], FakeTensor)
 
+    def test_fast_div(self):
+        mode = FakeTensorMode()
+        with mode:
+            x = torch.empty(2, 2, device="cpu", dtype=torch.int32)
+        from torch._subclasses.fake_impls import get_fast_op_impls
+        fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor]
+        y = fast_div(mode, x, 2)
+        self.assertEqual(y.dtype, torch.float32)
 
 instantiate_parametrized_tests(FakeTensorTest)
```
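Note the calling convention the new test exercises: the fast op implementations take the active FakeTensorMode as their first argument (see `fast_binary_impl(mode, *args, **kwargs)` in the diff below), so the test calls `fast_div(mode, x, 2)` directly rather than dispatching through the mode.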
torch/_subclasses/fake_impls.py

```
@@ -890,7 +890,9 @@ def infer_size(a, b):
     return tuple(expandedSizes)
 
 
-def make_fast_binary_impl(slow_ref):
+def make_fast_binary_impl(
+    slow_ref, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+):
     def fast_binary_impl(mode, *args, **kwargs):
         def slow(msg):
             count_label(f"slow {msg}")
@@ -957,7 +959,7 @@ def make_fast_binary_impl(slow_ref):
         # compute promotion
         # TODO: we don't need the compute type
         _, common_dtype = elementwise_dtypes(
-            *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+            *operands, type_promotion_kind=type_promotion_kind
         )
 
         # check all tensors on same device
@@ -1042,7 +1044,10 @@ def get_fast_op_impls():
     )
     register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
     register_fast_op_impl(torch.ops.aten.div.Tensor)(
-        make_fast_binary_impl(torch._refs.div)
+        make_fast_binary_impl(
+            torch._refs.div,
+            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        )
     )
     register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach)
     return FAST_OP_IMPLEMENTATIONS
```
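As background (a minimal sketch, not part of the diff, assuming the helpers live in `torch._prims_common`, where PyTorch defines them): the difference between the two promotion kinds can be observed directly with `elementwise_dtypes`, the helper the fast path calls above.

```python
import torch
from torch._prims_common import (
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    elementwise_dtypes,
)

x = torch.empty(2, 2, dtype=torch.int32)

# DEFAULT promotion keeps an all-integer operation integral...
_, default_dtype = elementwise_dtypes(
    x, 2, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
# ...while INT_TO_FLOAT, appropriate for true division, promotes
# integer inputs to the default float dtype.
_, div_dtype = elementwise_dtypes(
    x, 2, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)
print(default_dtype, div_dtype)  # torch.int32 torch.float32
```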