Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 18:13:46 +08:00)
switch XPU ccl backend to torch-builtin xccl in test_zero3_integration (#3773)
* switch XPU ccl backend to torch-builtin xccl in test_zero3_integration
  remove xpu workaround in RegressionModel, we are OK now
  rename test_multigpu to test_multidevice to reflect the fact

  Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* fix ci issues

  Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* xx

  Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
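
For context: "torch-builtin xccl" is the XCCL distributed backend that ships inside recent XPU-enabled PyTorch builds (a prototype from roughly 2.7 onward), replacing the external oneCCL bindings package that used to register the "ccl" backend. A minimal sketch of what the switch means at the call site; the version threshold and the torchrun-style launch are assumptions, not part of this commit:

    import torch
    import torch.distributed as dist

    # Old: needed `import oneccl_bindings_for_pytorch` first to register "ccl".
    # New: "xccl" is registered by PyTorch itself in XPU-enabled builds,
    # so no third-party import is required before creating the process group.
    dist.init_process_group(backend="xccl")  # was: backend="ccl"
    torch.xpu.set_device(dist.get_rank() % torch.xpu.device_count())
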
@@ -60,7 +60,7 @@ from .testing import (
     slow,
     torch_device,
 )
-from .training import RegressionDataset, RegressionModel, RegressionModel4XPU
+from .training import RegressionDataset, RegressionModel


 from .scripts import test_script, test_sync, test_ops  # isort: skip
|
@@ -28,7 +28,7 @@ GPT2_TINY = "sshleifer/tiny-gpt2"
 @require_huggingface_suite
 def init_torch_dist_then_launch_deepspeed():
     if torch_device == "xpu":
-        backend = "ccl"
+        backend = "xccl"
     elif torch_device == "hpu":
         backend = "hccl"
     else:
|
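The selected backend string then feeds the usual torch.distributed setup before DeepSpeed launches. A rough sketch of how the helper plausibly continues; the `else` branch value and the init call below are assumptions inferred from the function name, not shown in this hunk (`torch_device` is re-exported by `accelerate.test_utils`, per the first hunk):

    import torch.distributed as dist
    from accelerate.test_utils import torch_device

    def init_torch_dist_then_launch_deepspeed():
        if torch_device == "xpu":
            backend = "xccl"
        elif torch_device == "hpu":
            backend = "hccl"
        else:
            backend = "nccl"  # assumed CUDA default
        # Initialize torch.distributed first so DeepSpeed reuses the
        # already-created process group instead of creating its own.
        dist.init_process_group(backend=backend)
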
@@ -28,7 +28,7 @@ from torch.utils.data import DataLoader, Dataset
 from accelerate import Accelerator
 from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader
 from accelerate.state import AcceleratorState
-from accelerate.test_utils import RegressionDataset, are_the_same_tensors
+from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
 from accelerate.utils import (
     DataLoaderConfiguration,
     DistributedType,
@@ -42,18 +42,11 @@ from accelerate.utils import (
     is_ipex_available,
     is_mps_available,
     is_pytest_available,
-    is_xpu_available,
     set_seed,
     synchronize_rng_states,
 )


-# TODO: remove RegressionModel4XPU once ccl support empty buffer in broadcasting.
-if is_xpu_available():
-    from accelerate.test_utils import RegressionModel4XPU as RegressionModel
-else:
-    from accelerate.test_utils import RegressionModel
-
 if is_hpu_available():
     ATOL = 1e-3
     RTOL = 1e-3
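
The deleted block above was the workaround its own TODO flagged: the old "ccl" backend could not broadcast an empty buffer, so XPU runs were silently handed RegressionModel4XPU in place of the real model. A hypothetical repro of the limitation the switch removes; this is not code from the repo, and it assumes a multi-process torchrun launch on XPU:

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="xccl")  # torch-builtin XPU backend
    buf = torch.empty(0, device="xpu")       # zero-element buffer
    dist.broadcast(buf, src=0)               # failed under "ccl"; fine under xccl
    dist.destroy_process_group()
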
|
@ -33,20 +33,6 @@ class RegressionDataset:
|
|||||||
return {"x": self.x[i], "y": self.y[i]}
|
return {"x": self.x[i], "y": self.y[i]}
|
||||||
|
|
||||||
|
|
||||||
class RegressionModel4XPU(torch.nn.Module):
|
|
||||||
def __init__(self, a=0, b=0, double_output=False):
|
|
||||||
super().__init__()
|
|
||||||
self.a = torch.nn.Parameter(torch.tensor([2, 3]).float())
|
|
||||||
self.b = torch.nn.Parameter(torch.tensor([2, 3]).float())
|
|
||||||
self.first_batch = True
|
|
||||||
|
|
||||||
def forward(self, x=None):
|
|
||||||
if self.first_batch:
|
|
||||||
print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}")
|
|
||||||
self.first_batch = False
|
|
||||||
return x * self.a[0] + self.b[0]
|
|
||||||
|
|
||||||
|
|
||||||
class RegressionModel(torch.nn.Module):
|
class RegressionModel(torch.nn.Module):
|
||||||
def __init__(self, a=0, b=0, double_output=False):
|
def __init__(self, a=0, b=0, double_output=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
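
With the 4XPU variant gone, every backend exercises the same RegressionModel. Its body is elided by the diff context; for reference, a sketch extrapolated from the removed 4XPU copy above (the scalar parameters and the omitted double_output handling are assumptions):

    import torch

    class RegressionModel(torch.nn.Module):
        def __init__(self, a=0, b=0, double_output=False):
            super().__init__()
            # Plain scalar parameters -- the very values the 4XPU variant
            # dodged by hard-coding 2-element tensors for old-ccl broadcasts.
            self.a = torch.nn.Parameter(torch.tensor(a).float())
            self.b = torch.nn.Parameter(torch.tensor(b).float())
            self.first_batch = True

        def forward(self, x=None):
            if self.first_batch:
                print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}")
                self.first_batch = False
            return x * self.a + self.b
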
|