mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use torch.testing.test_close instead of torch.testing.test_allclose (#164539)
Because torch.testing.test_allclose is deprecated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164539 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
aed66248a0
commit
5743d731c1
@ -39,7 +39,7 @@ class ApplyOverlappedOptimizerTest(unittest.TestCase):
|
||||
with self.subTest(i):
|
||||
_validate_params(
|
||||
[model.parameters() for model in models],
|
||||
torch.testing.assert_allclose,
|
||||
torch.testing.assert_close,
|
||||
)
|
||||
|
||||
for opt in optimizers:
|
||||
@ -77,7 +77,7 @@ class ApplyOverlappedOptimizerTest(unittest.TestCase):
|
||||
model.parameters(),
|
||||
model_with_opt_in_bwd.parameters(),
|
||||
],
|
||||
torch.testing.assert_allclose,
|
||||
torch.testing.assert_close,
|
||||
)
|
||||
|
||||
self._run_training_loop_and_validate(
|
||||
@ -113,10 +113,10 @@ class ApplyOverlappedOptimizerTest(unittest.TestCase):
|
||||
|
||||
for p1, p2 in zip(model_with_hook.parameters(), initial_model.parameters()):
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_allclose(p1, p2)
|
||||
torch.testing.assert_close(p1, p2)
|
||||
|
||||
for p1, p2 in zip(model_no_hook.parameters(), initial_model.parameters()):
|
||||
torch.testing.assert_allclose(p1, p2)
|
||||
torch.testing.assert_close(p1, p2)
|
||||
|
||||
def test_multiple_optim_for_params(self) -> None:
|
||||
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
|
||||
|
@ -218,7 +218,7 @@ class TestCompilerBisector(TestCase):
|
||||
torch._dynamo.reset()
|
||||
|
||||
try:
|
||||
torch.testing.assert_allclose(torch.compile(op)(x), op(x))
|
||||
torch.testing.assert_close(torch.compile(op)(x), op(x))
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
@ -115,7 +115,7 @@ class TestInductorConfig(TestCase):
|
||||
for kwargs in checks:
|
||||
torch._dynamo.reset()
|
||||
opt_fn = torch.compile(dummy_fn, **kwargs)
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
opt_fn(x), y, msg=f"torch.compile(..., **{kwargs!r}) failed"
|
||||
)
|
||||
|
||||
|
@ -724,7 +724,7 @@ class TestFakeQuantizeOps(TestCase):
|
||||
X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
|
||||
Y_prime = torch.fake_quantize_per_channel_affine(
|
||||
X, scale, zero_point, axis, quant_min, quant_max)
|
||||
torch.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
|
||||
torch.testing.assert_close(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
|
||||
self.assertTrue(Y.dtype == float_type)
|
||||
|
||||
def test_forward_per_channel_cachemask_cpu(self):
|
||||
|
Reference in New Issue
Block a user