diff --git a/test/test_decomp.py b/test/test_decomp.py index 7b7bb6222bf5..d33882cf0602 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -234,16 +234,15 @@ def normalize_op_input_output2( # NB: This also upcasts dtype arguments - - -def upcast_tensor(func, x, dtype=torch.float32): +# TODO: handle complex correctly +def upcast_tensor(x, dtype=torch.float32): if isinstance(x, Tensor) and x.dtype.is_floating_point: return x.to(dtype=dtype) elif ( isinstance(x, torch.dtype) and x in [torch.float16, torch.bfloat16] ): - return torch.float64 + return dtype else: return x @@ -410,7 +409,7 @@ class TestDecomp(TestCase): assert len(real_out) == len(decomp_out) if do_relative_check: - upcast = partial(upcast_tensor, func, dtype=torch.float64) + upcast = partial(upcast_tensor, dtype=torch.float64) real_out_double, _ = tree_flatten( func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) )