enable unit tests (#25963)
Summary: These unit tests pass after landing all the warp-size awareness patches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/25963
Differential Revision: D17319124
Pulled By: bddppq
fbshipit-source-id: 22f5d5f1ca9c67e66a7ccf983b2d2f889a74e729
Committed by: Facebook Github Bot
Parent: 075adb4d2d
Commit: 00d967c39d
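The diff below removes the @skipIfRocm decorator from several TestOptim tests (and drops its now-unused import), so these tests run on ROCm builds as well. For orientation, here is a minimal sketch of how a skipIfRocm-style decorator typically works; the TEST_WITH_ROCM flag and the decorator body are illustrative assumptions, not code quoted from this commit:

    import unittest
    from functools import wraps

    # In common_utils a flag like this would be derived from the environment;
    # hardcoded here only to keep the sketch self-contained.
    TEST_WITH_ROCM = False

    def skipIfRocm(fn):
        """Skip the wrapped test when running on the ROCm stack."""
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if TEST_WITH_ROCM:
                raise unittest.SkipTest("test doesn't currently work on ROCm")
            return fn(*args, **kwargs)
        return wrapper

Removing the decorator therefore re-enables the test unconditionally.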
@@ -14,8 +14,7 @@ from torch import sparse
 from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \
     ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \
     CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR
-from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
-    skipIfRocm
+from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests
 
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -285,7 +284,6 @@ class TestOptim(TestCase):
             [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)]
         )
 
-    @skipIfRocm
     def test_adam(self):
         self._test_basic_cases(
             lambda weight, bias: optim.Adam([weight, bias], lr=1e-3)
@@ -401,7 +399,6 @@ class TestOptim(TestCase):
              lambda opt: ReduceLROnPlateau(opt, threshold=1e-4)]
         )
 
-    @skipIfRocm
     def test_adamax(self):
         self._test_basic_cases(
             lambda weight, bias: optim.Adamax([weight, bias], lr=1e-1)
@@ -426,7 +423,6 @@ class TestOptim(TestCase):
         with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
             optim.RMSprop(None, lr=1e-2, momentum=-1.0)
 
-    @skipIfRocm
     def test_asgd(self):
         self._test_basic_cases(
             lambda weight, bias: optim.ASGD([weight, bias], lr=1e-3, t0=100)
@@ -451,7 +447,6 @@ class TestOptim(TestCase):
         with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
             optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5))
 
-    @skipIfRocm
     def test_lbfgs(self):
         self._test_basic_cases(
             lambda weight, bias: optim.LBFGS([weight, bias]),
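The pattern these hunks exercise is the same throughout: _test_basic_cases builds a (weight, bias) parameter pair, constructs the optimizer from the given lambda, and optionally attaches the listed LR schedulers. The hyperparameter-validation tests rely on optimizers validating their defaults in __init__, before any parameters are touched, which is why they can pass None. A minimal standalone version of that pattern (the tensors and loop here are illustrative, not taken from the test file):

    import torch
    from torch import optim
    from torch.optim.lr_scheduler import StepLR

    # Stand-ins for the (weight, bias) pair the tests construct.
    weight = torch.randn(10, 5, requires_grad=True)
    bias = torch.randn(10, requires_grad=True)

    opt = optim.Adam([weight, bias], lr=1e-3)
    sched = StepLR(opt, gamma=0.99999, step_size=300)

    for _ in range(3):
        opt.zero_grad()
        loss = (weight.sum() + bias.sum()) ** 2
        loss.backward()
        opt.step()    # update the parameters first,
        sched.step()  # then advance the schedule

    # Validation happens at construction time, so no parameters are needed:
    try:
        optim.RMSprop(None, lr=1e-2, momentum=-1.0)
    except ValueError as e:
        print(e)  # Invalid momentum value: -1.0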