mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Generalize poison fork logic for each device backend (#144664)
# Motivation Generalize the posion_fork code to make it reusable across different devices. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144664 Approved by: https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
322f883c0c
commit
83bd0b63b5
@ -11,20 +11,18 @@ from torch.testing._internal.common_utils import (
|
||||
IS_ARM64,
|
||||
IS_LINUX,
|
||||
skipIfTorchDynamo,
|
||||
TEST_CUDA,
|
||||
TEST_PRIVATEUSE1,
|
||||
TEST_XPU,
|
||||
)
|
||||
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
||||
|
||||
|
||||
# define TEST_ROCM before changing TEST_CUDA
|
||||
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
|
||||
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
|
||||
# This TestCase should be mutually exclusive with other backends.
|
||||
HAS_CUDA = torch.backends.cuda.is_built()
|
||||
HAS_XPU = torch.xpu._is_compiled()
|
||||
HAS_MPS = torch.backends.mps.is_built()
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM or TEST_XPU,
|
||||
IS_ARM64 or not IS_LINUX or HAS_CUDA or HAS_XPU or HAS_MPS or TEST_PRIVATEUSE1,
|
||||
"Only on linux platform and mutual exclusive to other backends",
|
||||
)
|
||||
@torch.testing._internal.common_utils.markDynamoStrictTest
|
||||
|
Reference in New Issue
Block a user