enable test_dispatch_model_tied_weights_memory_with_nested_offload_cpu on xpu (#3569)

* enable test_dispatch_model_tied_weights_memory_with_nested_offload_cpu
case on XPU

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* replace hard-coded torch.cuda w/ device-dependent callings

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* use device agnostic clear_device_cache

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix
2025-05-15 17:40:55 +08:00
committed by GitHub
parent cd37bbb629
commit 97c93c4809
2 changed files with 25 additions and 20 deletions

View File

@ -36,7 +36,6 @@ from accelerate.big_modeling import (
from accelerate.hooks import remove_hook_from_submodules from accelerate.hooks import remove_hook_from_submodules
from accelerate.test_utils import ( from accelerate.test_utils import (
require_bnb, require_bnb,
require_cuda,
require_cuda_or_xpu, require_cuda_or_xpu,
require_multi_device, require_multi_device,
require_multi_gpu_or_xpu, require_multi_gpu_or_xpu,
@ -47,6 +46,7 @@ from accelerate.test_utils import (
torch_device, torch_device,
) )
from accelerate.utils import is_hpu_available, offload_state_dict from accelerate.utils import is_hpu_available, offload_state_dict
from accelerate.utils.memory import clear_device_cache
from accelerate.utils.versions import is_torch_version from accelerate.utils.versions import is_torch_version
@ -379,7 +379,7 @@ class BigModelingTester(unittest.TestCase):
torch_accelerator_module = getattr(torch, torch_device_type) torch_accelerator_module = getattr(torch, torch_device_type)
torch_accelerator_module.empty_cache() # Needed in case we run several tests in a row. clear_device_cache() # Needed in case we run several tests in a row.
model = nn.Sequential( model = nn.Sequential(
OrderedDict( OrderedDict(
@ -443,7 +443,7 @@ class BigModelingTester(unittest.TestCase):
# Test that we do not duplicate tied weights at any point during dispatch_model call. # Test that we do not duplicate tied weights at any point during dispatch_model call.
torch_accelerator_module = getattr(torch, torch_device_type) torch_accelerator_module = getattr(torch, torch_device_type)
torch_accelerator_module.empty_cache() # Needed in case we run several tests in a row. clear_device_cache() # Needed in case we run several tests in a row.
class SubModule(torch.nn.Module): class SubModule(torch.nn.Module):
def __init__(self, ref_to_parameter): def __init__(self, ref_to_parameter):
@ -521,7 +521,7 @@ class BigModelingTester(unittest.TestCase):
torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL) torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)
torch_accelerator_module.empty_cache() clear_device_cache()
free_memory_bytes_after_infer = torch_accelerator_module.mem_get_info(torch_device)[0] free_memory_bytes_after_infer = torch_accelerator_module.mem_get_info(torch_device)[0]
@ -536,14 +536,16 @@ class BigModelingTester(unittest.TestCase):
# This test fails because sometimes data_ptr() of compute2.weight is the same as compute1.weight. # This test fails because sometimes data_ptr() of compute2.weight is the same as compute1.weight.
# I checked that the values are not the same but it gives the same address. This does not happen on my local machine. # I checked that the values are not the same but it gives the same address. This does not happen on my local machine.
@require_cuda @require_cuda_or_xpu
@unittest.skip( @unittest.skip(
"Flaky test, we should have enough coverage with test_dispatch_model_tied_weights_memory_with_nested_offload_cpu test" "Flaky test, we should have enough coverage with test_dispatch_model_tied_weights_memory_with_nested_offload_cpu test"
) )
def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(self): def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(self):
# Test that we do not duplicate tied weights at any point during dispatch_model call. # Test that we do not duplicate tied weights at any point during dispatch_model call.
torch.cuda.empty_cache() # Needed in case we run several tests in a row. torch_accelerator_module = getattr(torch, torch_device_type)
clear_device_cache() # Needed in case we run several tests in a row.
class SubModule(torch.nn.Module): class SubModule(torch.nn.Module):
def __init__(self, ref_to_parameter): def __init__(self, ref_to_parameter):
@ -589,27 +591,33 @@ class BigModelingTester(unittest.TestCase):
expected = model(x) expected = model(x)
# Just to initialize CUDA context. # Just to initialize CUDA context.
a = torch.rand(5).to("cuda:0") # noqa: F841 device_0 = f"{torch_device_type}:0"
a = torch.rand(5).to(device_0) # noqa: F841
free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0] free_memory_bytes = torch_accelerator_module.mem_get_info(device_0)[0]
required_memory_bytes = 2 * 5000 * 5000 * (32 // 8) # 200 MB required_memory_bytes = 2 * 5000 * 5000 * (32 // 8) # 200 MB
# Leaving 150 MB of free memory for possible buffers, etc. # Leaving 150 MB of free memory for possible buffers, etc.
n_vals = (free_memory_bytes - required_memory_bytes - int(200e6)) // (32 // 8) n_vals = (free_memory_bytes - required_memory_bytes - int(200e6)) // (32 // 8)
foo = torch.rand(n_vals, device="cuda:0") # noqa: F841 foo = torch.rand(n_vals, device=device_0) # noqa: F841
free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0] free_memory_bytes_before_dispatch = torch_accelerator_module.mem_get_info(device_0)[0]
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
dispatch_model(model, device_map, offload_dir=tmp_dir) dispatch_model(model, device_map, offload_dir=tmp_dir)
free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0] free_memory_bytes_after_dispatch = torch_accelerator_module.mem_get_info(device_0)[0]
assert (free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130 assert (free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130
oom_error = (
torch.OutOfMemoryError
if hasattr(torch, "OutOfMemoryError")
else torch_accelerator_module.OutOfMemoryError
)
with torch.no_grad(): with torch.no_grad():
try: try:
output = model(x) output = model(x)
except torch.cuda.OutOfMemoryError as e: except oom_error as e:
raise torch.cuda.OutOfMemoryError( raise oom_error(
f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_disk. {e}" f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_disk. {e}"
) )
except Exception as e: except Exception as e:
@ -617,9 +625,9 @@ class BigModelingTester(unittest.TestCase):
torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL) torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)
torch.cuda.empty_cache() clear_device_cache()
free_memory_bytes_after_infer = torch.cuda.mem_get_info("cuda:0")[0] free_memory_bytes_after_infer = torch_accelerator_module.mem_get_info(device_0)[0]
# Check that we have no more references on GPU for the offloaded tied weight. # Check that we have no more references on GPU for the offloaded tied weight.
n_non_empty = 0 n_non_empty = 0

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import tempfile import tempfile
import unittest import unittest
@ -543,8 +542,7 @@ class MixedInt8LoaddedModelTest(unittest.TestCase):
del self.model_fp16 del self.model_fp16
del self.model_8bit del self.model_8bit
gc.collect() clear_device_cache(garbage_collection=True)
torch.cuda.empty_cache()
def test_memory_footprint(self): def test_memory_footprint(self):
r""" r"""
@ -663,8 +661,7 @@ class Bnb4BitEmptyModelTest(unittest.TestCase):
del self.model_fp16 del self.model_fp16
del self.model_4bit del self.model_4bit
gc.collect() clear_device_cache(garbage_collection=True)
torch.cuda.empty_cache()
def test_memory_footprint(self): def test_memory_footprint(self):
r""" r"""