mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Minor updates for upcast_tensor
Based on comments in https://github.com/pytorch/pytorch/pull/78459 Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/78525 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
d7dd0df22b
commit
fd37d1d870
@ -234,16 +234,15 @@ def normalize_op_input_output2(
|
||||
|
||||
|
||||
# NB: This also upcasts dtype arguments
|
||||
|
||||
|
||||
def upcast_tensor(func, x, dtype=torch.float32):
|
||||
# TODO: handle complex correctly
|
||||
def upcast_tensor(x, dtype=torch.float32):
|
||||
if isinstance(x, Tensor) and x.dtype.is_floating_point:
|
||||
return x.to(dtype=dtype)
|
||||
elif (
|
||||
isinstance(x, torch.dtype)
|
||||
and x in [torch.float16, torch.bfloat16]
|
||||
):
|
||||
return torch.float64
|
||||
return dtype
|
||||
else:
|
||||
return x
|
||||
|
||||
@ -410,7 +409,7 @@ class TestDecomp(TestCase):
|
||||
assert len(real_out) == len(decomp_out)
|
||||
|
||||
if do_relative_check:
|
||||
upcast = partial(upcast_tensor, func, dtype=torch.float64)
|
||||
upcast = partial(upcast_tensor, dtype=torch.float64)
|
||||
real_out_double, _ = tree_flatten(
|
||||
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
|
||||
)
|
||||
|
Reference in New Issue
Block a user