accelerate/tests/test_multidevice.py

Yao Matrix fe795fd324 switch XPU ccl backend to torch-builtin xccl in test_zero3_integration (#3773)

* switch the XPU ccl backend to the torch-builtin xccl backend in test_zero3_integration (see the sketch after this list)
* remove the XPU workaround in RegressionModel, since it is no longer needed
* rename test_multigpu to test_multidevice to reflect that these tests cover multiple device types, not just multi-GPU
* fix CI issues
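
A minimal sketch of what this kind of backend switch typically looks like, assuming a recent PyTorch build that ships the built-in "xccl" backend for XPU; the actual change lives in test_zero3_integration (not part of this file), and the helper name get_ccl_backend is hypothetical:

import torch
import torch.distributed as dist

def get_ccl_backend() -> str:
    # Assumption: this torch build provides the built-in "xccl" backend for XPU,
    # so importing the external oneCCL bindings is no longer required.
    if torch.xpu.is_available():
        return "xccl"
    return "nccl"

# Under a distributed launcher (RANK / WORLD_SIZE / MASTER_ADDR set), initialize the
# process group with the selected backend before the ZeRO-3 integration test body runs.
dist.init_process_group(backend=get_ccl_backend())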

2025-09-18 11:50:32 +02:00 · 175 lines · 6.6 KiB · Python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest
from unittest import skip

import torch

from accelerate import Accelerator
from accelerate.big_modeling import dispatch_model
from accelerate.test_utils import (
    DEFAULT_LAUNCH_COMMAND,
    assert_exception,
    device_count,
    execute_subprocess_async,
    get_launch_command,
    path_in_accelerate_package,
    require_huggingface_suite,
    require_multi_device,
    require_non_torch_xla,
    require_pippy,
    require_torchvision,
    run_first,
    torch_device,
)
from accelerate.utils import is_hpu_available, patch_environment


class MultiDeviceTester(unittest.TestCase):
    test_file_path = path_in_accelerate_package("test_utils", "scripts", "test_script.py")
    data_loop_file_path = path_in_accelerate_package("test_utils", "scripts", "test_distributed_data_loop.py")
    operation_file_path = path_in_accelerate_package("test_utils", "scripts", "test_ops.py")
    pippy_file_path = path_in_accelerate_package("test_utils", "scripts", "external_deps", "test_pippy.py")
    merge_weights_file_path = path_in_accelerate_package("test_utils", "scripts", "test_merge_weights.py")
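
    # Each attribute above points to a helper script shipped with accelerate; the tests
    # below launch these scripts (or this file itself) as distributed subprocesses.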

    @run_first
    @require_multi_device
    def test_multi_device(self):
        print(f"Found {device_count} {torch_device} devices.")
        cmd = DEFAULT_LAUNCH_COMMAND + [self.test_file_path]
        with patch_environment(omp_num_threads=1):
            execute_subprocess_async(cmd)

    @run_first
    @require_multi_device
    def test_multi_device_ops(self):
        print(f"Found {device_count} {torch_device} devices.")
        cmd = DEFAULT_LAUNCH_COMMAND + [self.operation_file_path]
        with patch_environment(omp_num_threads=1):
            execute_subprocess_async(cmd)
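
    # This test launches the current file itself as the distributed script, so the real
    # checks live in the `if __name__ == "__main__":` block at the bottom of the module.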

    @run_first
    @require_multi_device
    def test_pad_across_processes(self):
        print(f"Found {device_count} {torch_device} devices.")
        cmd = DEFAULT_LAUNCH_COMMAND + [inspect.getfile(self.__class__)]
        with patch_environment(omp_num_threads=1):
            execute_subprocess_async(cmd)

    @run_first
    @require_multi_device
    def test_multi_device_merge_fsdp_weights(self):
        print(f"Found {device_count} {torch_device} devices.")
        cmd = DEFAULT_LAUNCH_COMMAND + [self.merge_weights_file_path]
        env_kwargs = dict(omp_num_threads=1)
        with patch_environment(**env_kwargs):
            execute_subprocess_async(cmd)

    @run_first
    @require_non_torch_xla
    @require_multi_device
    def test_distributed_data_loop(self):
        """
        This TestCase checks the behaviour that occurs during distributed training or evaluation,
        when the batch size does not evenly divide the dataset size.
        """
        print(f"Found {device_count} devices, using 2 devices only")
        cmd = get_launch_command(num_processes=2) + [self.data_loop_file_path]
        env_kwargs = dict(omp_num_threads=1)
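        # Restrict the launch to the first two devices; the name of the "visible devices"
        # environment variable differs per accelerator backend.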
        if torch_device == "xpu":
            env_kwargs.update(ze_affinity_mask="0,1")
        elif torch_device == "npu":
            env_kwargs.update(ascend_rt_visible_devices="0,1")
        elif torch_device == "mlu":
            env_kwargs.update(mlu_visible_devices="0,1")
        elif torch_device == "sdaa":
            env_kwargs.update(sdaa_visible_devices="0,1")
        else:
            env_kwargs.update(cuda_visible_devices="0,1")
        with patch_environment(**env_kwargs):
            execute_subprocess_async(cmd)

    @run_first
    @require_pippy
    @require_torchvision
    @require_multi_device
    @require_huggingface_suite
    @skip("Will soon deprecate pippy")
    def test_pippy(self):
        """
        Checks the integration with the pippy framework
        """
        print(f"Found {device_count} {torch_device} devices")
        cmd = get_launch_command(multi_gpu=True, num_processes=device_count) + [self.pippy_file_path]
        with patch_environment(omp_num_threads=1):
            execute_subprocess_async(cmd)
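

# The block below runs on every process when test_pad_across_processes launches this file
# under the distributed launcher. Each rank builds a tensor whose first dimension is
# process_index + 2, so after pad_across_processes every rank should hold a tensor whose
# first dimension equals num_processes + 1 (the largest size across ranks).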
if __name__ == "__main__":
    accelerator = Accelerator()
    shape = (accelerator.state.process_index + 2, 10)
    tensor = torch.randint(0, 10, shape).to(accelerator.device)

    error_msg = ""

    tensor1 = accelerator.pad_across_processes(tensor)
    if tensor1.shape[0] != accelerator.state.num_processes + 1:
        error_msg += f"Found shape {tensor1.shape} but should have {accelerator.state.num_processes + 1} at dim 0."
    index = accelerator.state.process_index + 2
    if not torch.equal(tensor1[:index], tensor):
        error_msg += "Tensors have different values."
    if not torch.all(tensor1[index:] == 0):
        error_msg += "Padding was not done with the right value (0)."
    tensor2 = accelerator.pad_across_processes(tensor.clone(), pad_first=True)
    if tensor2.shape[0] != accelerator.state.num_processes + 1:
        error_msg += f"Found shape {tensor2.shape} but should have {accelerator.state.num_processes + 1} at dim 0."
    index = accelerator.state.num_processes - accelerator.state.process_index - 1
    if not torch.equal(tensor2[index:], tensor):
        error_msg += "Tensors have different values."
    if not torch.all(tensor2[:index] == 0):
        error_msg += "Padding was not done with the right value (0)."

    # Raise error at the end to make sure we don't stop at the first failure.
    if len(error_msg) > 0:
        raise ValueError(error_msg)

    # Check device_map
    accelerator.print("Test `device_map` cannot be prepared.")

    class ModelForTest(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(3, 4)
            self.batchnorm = torch.nn.BatchNorm1d(4)
            self.linear2 = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear2(self.batchnorm(self.linear1(x)))
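
    # Spread the model across devices (and the CPU) with dispatch_model; on HPU the map
    # keeps both linear layers on device 0, elsewhere linear2 goes to a second device.
    # prepare_model must then refuse a model that has been dispatched this way.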
    if is_hpu_available():
        device_map = {"linear1": 0, "batchnorm": "cpu", "linear2": 0}
    else:
        device_map = {"linear1": 0, "batchnorm": "cpu", "linear2": 1}

    model = ModelForTest()
    dispatch_model(model, device_map=device_map)
    with assert_exception(ValueError, "You can't train a model that has been loaded with"):
        model = accelerator.prepare_model(model)