Minor updates for upcast_tensor

Based on comments in https://github.com/pytorch/pytorch/pull/78459

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78525

Approved by: https://github.com/albanD
This commit is contained in:
Edward Z. Yang
2022-05-31 14:05:49 -07:00
committed by PyTorch MergeBot
parent d7dd0df22b
commit fd37d1d870

View File

@ -234,16 +234,15 @@ def normalize_op_input_output2(
# NB: This also upcasts dtype arguments
def upcast_tensor(func, x, dtype=torch.float32):
# TODO: handle complex correctly
def upcast_tensor(x, dtype=torch.float32):
if isinstance(x, Tensor) and x.dtype.is_floating_point:
return x.to(dtype=dtype)
elif (
isinstance(x, torch.dtype)
and x in [torch.float16, torch.bfloat16]
):
return torch.float64
return dtype
else:
return x
@ -410,7 +409,7 @@ class TestDecomp(TestCase):
assert len(real_out) == len(decomp_out)
if do_relative_check:
upcast = partial(upcast_tensor, func, dtype=torch.float64)
upcast = partial(upcast_tensor, dtype=torch.float64)
real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)