[minor] use set_default_dtype instead of try and finally (#88295)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88295
Approved by: https://github.com/mruberry
This commit is contained in:
kshitij12345
2022-11-03 19:28:33 +00:00
committed by PyTorch MergeBot
parent f8b73340c8
commit fe3a226d74
4 changed files with 13 additions and 31 deletions

View File

@ -8,7 +8,8 @@ import unittest
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY, skipCUDAMemoryLeakCheckIf
from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY,
set_default_dtype, skipCUDAMemoryLeakCheckIf)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
@ -130,11 +131,8 @@ class TestPrims(TestCase):
batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
shapes = [(), (0,), (1,), (5,)]
try:
# Sets the default dtype to NumPy's default dtype of double
cur_default = torch.get_default_dtype()
torch.set_default_dtype(torch.double)
# Sets the default dtype to NumPy's default dtype of double
with set_default_dtype(torch.double):
# Tested here, as this OP is not currently exposed or tested in ATen
for b, s in product(batches, shapes):
x = make_arg(b + s)
@ -144,8 +142,6 @@ class TestPrims(TestCase):
y_np = scipy.special.cbrt(x_np)
self.assertEqual(y, y_np, exact_device=False)
finally:
torch.set_default_dtype(cur_default)
@onlyCUDA
@skipCUDAIfRocm