Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Remove old ROCm version check in tests (#164245)
This PR removes ROCm<6 version checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164245
Approved by: https://github.com/jeffdaily
Committed by: PyTorch MergeBot
Parent commit: 3912ba3e94
This commit: b63bbe1661
@@ -36,7 +36,6 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     skipIfRocm,
-    skipIfRocmVersionLessThan,
     TEST_CUDA,
     TEST_WITH_ROCM,
     TestCase,
@@ -144,7 +143,6 @@ class TestMatmulCuda(InductorTestCase):
         torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate

     @onlyCUDA
-    @skipIfRocmVersionLessThan((5, 2))
     # imported 'tol' as 'xtol' to avoid aliasing in code above
     @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                         torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
@@ -158,7 +156,6 @@ class TestMatmulCuda(InductorTestCase):

     @onlyCUDA
     @xfailIfSM100OrLaterAndCondition(lambda params: params.get('dtype') == torch.bfloat16 and params.get('size') == 10000)
-    @skipIfRocmVersionLessThan((5, 2))
     # imported 'tol' as 'xtol' to avoid aliasing in code above
     @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
@@ -170,7 +167,6 @@ class TestMatmulCuda(InductorTestCase):
         self.cublas_addmm(size, dtype, True)

     @onlyCUDA
-    @skipIfRocmVersionLessThan((5, 2))
     @dtypes(torch.float16)
     # m == 4 chooses OUTPUT_TYPE reduction on H200
     # m == 8 chooses OUTPUT_TYPE reduction on A100
@@ -191,7 +187,6 @@ class TestMatmulCuda(InductorTestCase):
         torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_precision

     @onlyCUDA
-    @skipIfRocmVersionLessThan((5, 2))
     # imported 'tol' as 'xtol' to avoid aliasing in code above
     @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
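For context, a minimal sketch of how the removed skipIfRocmVersionLessThan decorator gates a test: it skips the test on ROCm builds older than the given version, so with ROCm < 6 no longer supported a (5, 2) threshold is always satisfied and the decorator is dead code. The test class, test body, and shapes below are hypothetical and not part of this commit; the helpers are assumed to come from torch.testing._internal.common_utils and torch.testing._internal.common_device_type, as in the import hunk above.

# Illustrative sketch only, not code from this commit.
import torch
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfRocmVersionLessThan,
    TestCase,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
)


class ExampleMatmulTest(TestCase):
    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))  # skipped on ROCm older than 5.2; redundant once ROCm < 6 is unsupported
    def test_addmm_fp16(self, device):
        # Hypothetical test body: a small fp16 addmm on the selected device.
        a = torch.randn(8, 8, device=device, dtype=torch.float16)
        b = torch.randn(8, 8, device=device, dtype=torch.float16)
        c = torch.randn(8, 8, device=device, dtype=torch.float16)
        out = torch.addmm(c, a, b)
        self.assertEqual(out.shape, (8, 8))


# Generate per-device test variants, restricted to CUDA/ROCm devices.
instantiate_device_type_tests(ExampleMatmulTest, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()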