[minor] use set_default_dtype instead of try and finally (#88295)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88295 Approved by: https://github.com/mruberry
Committed by: PyTorch MergeBot
Parent: f8b73340c8
Commit: fe3a226d74
@@ -8,7 +8,8 @@ import unittest
 import torch
 from torch.testing import make_tensor
-from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY, skipCUDAMemoryLeakCheckIf
+from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY,
+                                                  set_default_dtype, skipCUDAMemoryLeakCheckIf)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCUDA,
@@ -130,11 +131,8 @@ class TestPrims(TestCase):
         batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
         shapes = [(), (0,), (1,), (5,)]

-        try:
-            # Sets the default dtype to NumPy's default dtype of double
-            cur_default = torch.get_default_dtype()
-            torch.set_default_dtype(torch.double)
-
+        # Sets the default dtype to NumPy's default dtype of double
+        with set_default_dtype(torch.double):
             # Tested here, as this OP is not currently exposed or tested in ATen
             for b, s in product(batches, shapes):
                 x = make_arg(b + s)
@@ -144,8 +142,6 @@ class TestPrims(TestCase):
                 y_np = scipy.special.cbrt(x_np)

                 self.assertEqual(y, y_np, exact_device=False)
-        finally:
-            torch.set_default_dtype(cur_default)

     @onlyCUDA
     @skipCUDAIfRocm
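For readers skimming the diff, a minimal sketch of the pattern this change adopts, assuming a PyTorch environment where the internal test helper `torch.testing._internal.common_utils.set_default_dtype` is importable (it is the same helper the diff imports). It acts as a context manager that restores the previous default dtype on exit, which is what replaces the manual try/finally bookkeeping removed above:

import torch
from torch.testing._internal.common_utils import set_default_dtype  # internal test utility, as imported in the diff

# Old pattern (removed by this commit): save, set, and restore by hand.
saved = torch.get_default_dtype()
try:
    torch.set_default_dtype(torch.double)
    assert torch.empty(3).dtype is torch.double
finally:
    torch.set_default_dtype(saved)

# New pattern: the context manager restores the previous default on exit,
# even if the body raises, so no finally block is needed.
with set_default_dtype(torch.double):
    assert torch.empty(3).dtype is torch.double
assert torch.empty(3).dtype is saved  # back to the original default

Either form leaves the process-wide default dtype unchanged after the block; the context manager simply makes the restore exception-safe by construction.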