Compare commits

...

25 Commits

SHA1 Message Date
6467773bef Fix lint issue 2025-11-12 09:58:32 +00:00
34ba12b754 Move test_gpu definition to test_utils, due to each py file may define it base on the device capability. 2025-11-12 09:58:32 +00:00
91ce63d168 Replace test accelerator flag with test gpu 2025-11-12 09:58:32 +00:00
46d69c75e4 Revert schema test relative changes 2025-11-12 09:58:32 +00:00
1705f24744 Fix typo for the test utils 2025-11-12 09:58:32 +00:00
4bd236dc2c Add space for the intend of not support test case names 2025-11-12 09:58:32 +00:00
6c6e2a0ec3 Replace the xpu and cuda flag 2025-11-12 09:58:32 +00:00
8ef9ce0d75 Fix typo 2025-11-12 09:58:32 +00:00
63142cb3f0 Add has_xpu for the if condition 2025-11-12 09:58:32 +00:00
e28d1eaf29 Change the skip test case method with condition check for the known failed cases 2025-11-12 09:58:32 +00:00
bdff027e2b Fix the lint issue for test_utils.py 2025-11-12 09:58:32 +00:00
4666b764c7 Move the has_xpu, has_cuda, has_gpu defintion into the common_utils module 2025-11-12 09:58:32 +00:00
e8103e4535 Skip the failed test schema cases and opened defect #2297 in xpu ops 2025-11-12 09:58:32 +00:00
7ef461224c Fix lint issue for the device type assignment 2025-11-12 09:58:32 +00:00
f1c072fc3f Add type annotation for the device type value for checkpoint module 2025-11-12 09:58:32 +00:00
1d0cb1c715 remove unused import 2025-11-12 09:58:32 +00:00
95a6fd43c2 Temporary remove hipify test limitation for xpu 2025-11-12 09:58:32 +00:00
e601ee454a Change the default value for the device type in the checkpoint to avoid init the class property device_type during the module init time. 2025-11-12 09:58:32 +00:00
b52efb8b4c Update test/test_utils.py
Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
2025-11-12 09:58:31 +00:00
6b233a6d6e Update test/test_utils.py
Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
2025-11-12 09:58:31 +00:00
9da4b34148 Fix lint issue 2025-11-12 09:58:31 +00:00
48946f28f0 Add and update default device cases. used accelerator to replace the get_device_module for the synchronize and memory allocation method 2025-11-12 09:58:31 +00:00
7d9196fcae Remove redundant has_gpu 2025-11-12 09:58:31 +00:00
ce9f77dfec Remove the g_device_type and replace with device_type 2025-11-12 09:58:31 +00:00
8a2f7701f2 migrated two test files to xpu 2025-11-12 09:58:31 +00:00
2 changed files with 55 additions and 44 deletions

test/test_utils.py

@@ -53,8 +53,10 @@ from torch.utils.data import DataLoader
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests # noqa: PLW0127
HAS_CUDA = torch.cuda.is_available()
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
TEST_GPU = torch.xpu.is_available() or torch.cuda.is_available()
from torch.testing._internal.common_utils import run_tests, TestCase
@@ -302,24 +304,24 @@ class TestCheckpoint(TestCase):
self.assertEqual(grad_with_checkpointing, grad_no_checkpointing)
@unittest.skipIf(not HAS_CUDA, "No CUDA")
def test_checkpoint_rng_cuda(self):
@unittest.skipIf(not TEST_GPU, "No accelerator")
def test_checkpoint_rng_gpu(self):
for _ in range(5):
inp = torch.randn(20000, device="cuda").requires_grad_()
inp = torch.randn(20000, device=device_type).requires_grad_()
phase1 = torch.nn.Dropout()
phase2 = torch.nn.Dropout()
def run_fn(input):
return phase2(input)
state = torch.cuda.get_rng_state()
state = torch.get_device_module(device_type).get_rng_state()
out = phase1(inp)
out = checkpoint(run_fn, out, use_reentrant=True)
out.sum().backward()
grad_with_checkpointing = inp.grad
torch.cuda.set_rng_state(state)
torch.get_device_module(device_type).set_rng_state(state)
inp.grad = None
@@ -330,9 +332,9 @@ class TestCheckpoint(TestCase):
self.assertEqual(grad_with_checkpointing, grad_no_checkpointing)
@unittest.skipIf(not HAS_CUDA, "No CUDA")
@unittest.skipIf(not TEST_GPU, "No accelerator")
def test_checkpoint_not_preserve_rng_state_and_without_reentrant(self):
inp = torch.randn(2, device="cuda").requires_grad_()
inp = torch.randn(2, device=device_type).requires_grad_()
layer = torch.nn.Dropout()
def run_fn(input):
@@ -435,10 +437,10 @@ class TestCheckpoint(TestCase):
out = checkpoint(run_fn2, input_var, input_var2, use_reentrant=True)
out.sum().backward()
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
@unittest.skipIf(not TEST_GPU, "No accelerator")
def test_checkpointing_without_reentrant_early_free(self):
# I don't know how to check if the temporary saved variable buffer
# get de-allocated directly. So using cuda memory usage as a proxy
# get de-allocated directly. So using GPU memory usage as a proxy
def _do_test(fn, should_free):
stats: list[int] = []
@@ -449,8 +451,8 @@ class TestCheckpoint(TestCase):
# emptied at each step)
def hook(_unused):
self.assertEqual(len(stats), idx)
torch.cuda.synchronize()
stats.append(torch.cuda.memory_allocated())
torch.accelerator.synchronize()
stats.append(torch.accelerator.memory_allocated())
if idx > 0:
if should_free:
self.assertLess(stats[idx], stats[idx - 1])
@@ -475,7 +477,7 @@ class TestCheckpoint(TestCase):
return stats
x = torch.zeros(10, device="cuda", requires_grad=True)
x = torch.zeros(10, device=device_type, requires_grad=True)
x.grad = torch.zeros_like(x)
# In a regular backward, buffers get eagerly freed
@@ -505,8 +507,8 @@ class TestCheckpoint(TestCase):
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_get_device_states_recursive(self):
inp = {
"foo": torch.rand(10, device="cuda:0"),
"bar": [torch.rand(10, device="cuda:1")],
"foo": torch.rand(10, device=f"{device_type}:0"),
"bar": [torch.rand(10, device=f"{device_type}:1")],
}
device_ids, device_states = get_device_states(inp)
self.assertEqual(2, len(device_ids))
@@ -522,42 +524,42 @@ class TestCheckpoint(TestCase):
self.assertEqual("meta", device_type)
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_infer_device_state_recursive_multi_cuda(self):
# Check that no warning is issued for either cuda:0, cuda:1 or
# cuda:0, cuda:0 cases since they are both the same device type
def test_infer_device_state_recursive_multi_gpu(self):
# Check that no warning is issued for either gpu:0, gpu:1 or
# gpu:0, gpu:0 cases since they are both the same device type
inp = {
"foo": torch.rand(10, device="cuda:0"),
"bar": [torch.rand(10, device="cuda:1")],
"foo": torch.rand(10, device=f"{device_type}:0"),
"bar": [torch.rand(10, device=f"{device_type}:1")],
}
with warnings.catch_warnings():
warnings.simplefilter("error")
device_type = _infer_device_type(inp)
self.assertEqual("cuda", device_type)
_device_type = _infer_device_type(inp)
self.assertEqual(device_type, _device_type)
inp = {
"foo": torch.rand(10, device="cuda:0"),
"bar": [torch.rand(10, device="cuda:0")],
"foo": torch.rand(10, device=f"{device_type}:0"),
"bar": [torch.rand(10, device=f"{device_type}:0")],
}
with warnings.catch_warnings():
warnings.simplefilter("error")
device_type = _infer_device_type(inp)
self.assertEqual("cuda", device_type)
# Check that a warning is issued for cuda:0, meta and that it includes
_device_type = _infer_device_type(inp)
self.assertEqual(device_type, _device_type)
# Check that a warning is issued for gpu:0, meta and that it includes
# device type information
inp = {
"foo": torch.rand(10, device="cuda:0"),
"foo": torch.rand(10, device=f"{device_type}:0"),
"bar": [torch.rand(10, device="meta")],
}
with warnings.catch_warnings(record=True) as w:
device_type = _infer_device_type(inp)
self.assertEqual("cuda", device_type)
_device_type = _infer_device_type(inp)
self.assertEqual(device_type, _device_type)
self.assertEqual(len(w), 1)
warning_msg = str(w[-1].message)
self.assertTrue(
"Tensor arguments, excluding CPU tensors, are detected on at least two types of devices"
in warning_msg
)
self.assertTrue("Device types: ['cuda', 'meta']" in warning_msg)
self.assertTrue("first device type: cuda" in warning_msg)
self.assertTrue(f"Device types: ['{device_type}', 'meta']" in warning_msg)
self.assertTrue(f"first device type: {device_type}" in warning_msg)
class TestDataLoaderUtils(TestCase):
@@ -604,7 +606,7 @@ class TestDataLoaderUtils(TestCase):
self.assertEqual(len(list(dataiter)), 1)
@unittest.skip(
"FIXME: Intermittent CUDA out-of-memory error on Windows and time-out under ASAN"
"FIXME: Intermittent GPU out-of-memory error on Windows and time-out under ASAN"
)
def test_multi_keep(self):
dataloader: DataLoader = DataLoader(
@@ -861,27 +863,33 @@ class TestDeviceUtils(TestCase):
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_get_default_device_more(self):
try:
torch.set_default_device("cuda")
torch.set_default_device(device_type)
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
torch.set_default_device(None)
torch.set_default_device("cuda")
torch.cuda.set_device("cuda:1")
torch.set_default_device(device_type)
torch.get_device_module(device_type).set_device(f"{device_type}:1")
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
torch.accelerator.set_device_index(1)
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
torch.set_default_device(None)
torch.set_default_device("cuda:1")
torch.set_default_device(f"{device_type}:1")
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
torch.set_default_device(None)
torch.set_default_device("cuda:1")
with torch.device("cuda:0"):
self.assertEqual(torch.get_default_device(), torch.device("cuda", 0))
torch.set_default_device(f"{device_type}:1")
with torch.device(f"{device_type}:0"):
self.assertEqual(
torch.get_default_device(), torch.device(f"{device_type}", 0)
)
torch.set_default_device("cpu")
self.assertEqual(torch.get_default_device(), torch.device("cpu"))
with torch.device("cuda:0"):
self.assertEqual(torch.get_default_device(), torch.device("cuda", 0))
with torch.device(f"{device_type}:0"):
self.assertEqual(
torch.get_default_device(), torch.device(f"{device_type}", 0)
)
self.assertEqual(torch.get_default_device(), torch.device("cpu"))
finally:
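
For readers skimming the diff above, the recurring change is the same device-agnostic pattern: a module-level device_type resolved once through torch.accelerator.current_accelerator(True), a TEST_GPU gate replacing HAS_CUDA, and torch.get_device_module(device_type) / torch.accelerator.* replacing direct torch.cuda.* calls. Below is a minimal standalone sketch of that pattern, assuming a PyTorch build that ships the torch.accelerator API (memory_allocated is the call used in the diff and only exists in newer releases); the scaffolding around the calls is illustrative, not part of the PR.

import torch

# Resolve the active accelerator type ("cuda", "xpu", ...) once, falling back to "cpu".
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)
# Gate GPU-only tests on either backend instead of CUDA alone.
TEST_GPU = torch.xpu.is_available() or torch.cuda.is_available()

if TEST_GPU:
    # Device strings come from device_type instead of a hard-coded "cuda".
    x = torch.randn(8, device=device_type)
    # Per-backend RNG state goes through the matching device module.
    state = torch.get_device_module(device_type).get_rng_state()
    torch.get_device_module(device_type).set_rng_state(state)
    # Synchronization and memory queries go through the generic accelerator API.
    torch.accelerator.synchronize()
    print(device_type, torch.accelerator.memory_allocated(), x.sum().item())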

torch/utils/checkpoint.py

@@ -106,7 +106,7 @@ class DefaultDeviceType:
to save and restore for recomputation.
"""
_default_device_type = "cuda"
_default_device_type: Optional[str] = None
@staticmethod
def set_device_type(device: str = "cuda") -> None:
@@ -126,6 +126,9 @@ class DefaultDeviceType:
Returns:
str: The current default device type.
"""
if not DefaultDeviceType._default_device_type:
DefaultDeviceType._default_device_type = acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
return DefaultDeviceType._default_device_type
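
The checkpoint-side change above swaps the hard-coded "cuda" default for lazy resolution: _default_device_type starts as None and is only filled in when get_device_type() is first read, so importing the checkpoint module no longer assumes a CUDA default and picks up whatever accelerator (or CPU) is actually present. A minimal sketch of that lazy-initialization pattern, mirroring the names in the diff but not the actual torch.utils.checkpoint source:

from typing import Optional

import torch

class DefaultDeviceType:
    # No device is baked in at import time.
    _default_device_type: Optional[str] = None

    @staticmethod
    def set_device_type(device: str = "cuda") -> None:
        # Explicit override for callers that want to pin the device type.
        DefaultDeviceType._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        # Resolve lazily on first read: active accelerator if one exists, else "cpu".
        if not DefaultDeviceType._default_device_type:
            DefaultDeviceType._default_device_type = (
                acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
            )
        return DefaultDeviceType._default_device_type

Checkpointing uses this value to decide which device's RNG state to save and restore for recomputation, which is why deferring the choice until first use matters for non-CUDA accelerators.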