Mirror of https://github.com/huggingface/accelerate.git, synced 2025-10-20 18:13:46 +08:00
enable test_cli & test_example cases on XPU (#3578)
* enable test_cli & test_example cases on XPU
* fix style
* fix style
* remove print
* fix ci issue

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
@@ -21,10 +21,7 @@ from accelerate.test_utils import torch_device
 from accelerate.utils import set_seed
 
 
-if torch_device == "hpu":
-    synchronize_func = torch.hpu.synchronize
-else:
-    synchronize_func = torch.cuda.synchronize
+synchronize_func = getattr(torch, torch_device, torch.cuda).synchronize
 
 # Set the random seed to have reproducable outputs
 set_seed(42)
@@ -21,11 +21,7 @@ from accelerate.test_utils import torch_device
 from accelerate.utils import set_seed
 
 
-if torch_device == "hpu":
-    synchronize_func = torch.hpu.synchronize
-else:
-    synchronize_func = torch.cuda.synchronize
-
+synchronize_func = getattr(torch, torch_device, torch.cuda).synchronize
 
 # Set the random seed to have reproducable outputs
 set_seed(42)
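Both hunks above replace the same hand-written hpu/cuda branch with a single attribute lookup: `torch_device` is a backend name such as "cuda", "xpu", or "hpu", each of which torch exposes as a submodule with its own `synchronize`, and `torch.cuda` is the fallback when no matching submodule exists. A minimal standalone sketch of that pattern (how `torch_device` is derived below is an assumption; in the benchmark files it comes from `accelerate.test_utils.torch_device`):

import torch

# Assumed stand-in for accelerate.test_utils.torch_device.
torch_device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

# getattr(torch, "xpu") -> torch.xpu, getattr(torch, "hpu") -> torch.hpu, etc.;
# torch.cuda is used when torch has no submodule named torch_device.
synchronize_func = getattr(torch, torch_device, torch.cuda).synchronize

if getattr(torch, torch_device).is_available():
    synchronize_func()  # block until all queued kernels on the active device finish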
@@ -13,13 +13,19 @@
 # limitations under the License.
 import torch
 
+from accelerate.utils import is_xpu_available
+
 
 def main():
+    accelerator_type = "GPU"
+    num_accelerators = 0
     if torch.cuda.is_available():
-        num_gpus = torch.cuda.device_count()
-    else:
-        num_gpus = 0
-    print(f"Successfully ran on {num_gpus} GPUs")
+        num_accelerators = torch.cuda.device_count()
+        accelerator_type = "GPU"
+    elif is_xpu_available():
+        num_accelerators = torch.xpu.device_count()
+        accelerator_type = "XPU"
+    print(f"Successfully ran on {num_accelerators} {accelerator_type}s")
 
 
 if __name__ == "__main__":
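The rewritten `main()` above reports whichever accelerator family is present instead of assuming CUDA. A hedged standalone version of the same detection (using `hasattr` as a stand-in for `accelerate.utils.is_xpu_available`, which also covers older builds where the `torch.xpu` module is missing):

import torch


def count_accelerators():
    # Prefer CUDA devices, then Intel XPUs; report zero accelerators when neither stack is usable.
    if torch.cuda.is_available():
        return torch.cuda.device_count(), "GPU"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.device_count(), "XPU"
    return 0, "GPU"


if __name__ == "__main__":
    num_accelerators, accelerator_type = count_accelerators()
    print(f"Successfully ran on {num_accelerators} {accelerator_type}s")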
@@ -183,13 +183,7 @@ def find_executable_batch_size(
 
 
 def get_xpu_available_memory(device_index: int):
-    if is_ipex_available():
-        ipex_version = version.parse(importlib.metadata.version("intel_extension_for_pytorch"))
-        if compare_versions(ipex_version, ">=", "2.5"):
-            from intel_extension_for_pytorch.xpu import mem_get_info
-
-            return mem_get_info(device_index)[0]
-    elif version.parse(torch.__version__).release >= version.parse("2.6").release:
+    if version.parse(torch.__version__).release >= version.parse("2.6").release:
         # torch.xpu.mem_get_info API is available starting from PyTorch 2.6
         # It further requires PyTorch built with the SYCL runtime which supports API
         # to query available device memory. If not available, exception will be
@@ -200,6 +194,12 @@ def get_xpu_available_memory(device_index: int):
             return torch.xpu.mem_get_info(device_index)[0]
         except Exception:
             pass
+    elif is_ipex_available():
+        ipex_version = version.parse(importlib.metadata.version("intel_extension_for_pytorch"))
+        if compare_versions(ipex_version, ">=", "2.5"):
+            from intel_extension_for_pytorch.xpu import mem_get_info
+
+            return mem_get_info(device_index)[0]
 
     warnings.warn(
         "The XPU `mem_get_info` API is available in IPEX version >=2.5 or PyTorch >=2.6. The current returned available memory is incorrect. Please consider upgrading your IPEX or PyTorch version."
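The two hunks above flip the priority inside `get_xpu_available_memory`: the native `torch.xpu.mem_get_info` (available from PyTorch 2.6, and only when the underlying SYCL runtime supports the query) is tried first, and the IPEX `mem_get_info` import becomes the fallback for older stacks. A simplified sketch of that ordering, with the IPEX version check and the warning path of the real helper trimmed away:

import torch
from packaging import version


def xpu_free_memory(device_index: int = 0) -> int:
    """Best-effort free memory on an XPU, preferring the native PyTorch API (sketch)."""
    if version.parse(torch.__version__).release >= version.parse("2.6").release:
        try:
            # Returns (free_bytes, total_bytes); raises if the SYCL runtime cannot answer.
            return torch.xpu.mem_get_info(device_index)[0]
        except Exception:
            pass
    try:
        # Fallback for IPEX installs (>= 2.5 in the real helper) on older PyTorch.
        from intel_extension_for_pytorch.xpu import mem_get_info

        return mem_get_info(device_index)[0]
    except ImportError:
        # The real helper warns here that the reported value is unreliable.
        return 0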
@@ -72,6 +72,8 @@ class AccelerateLauncherTester(unittest.TestCase):
         args = ["--monitor_interval", "0.1", str(self.test_file_path)]
         if torch.cuda.is_available() and (torch.cuda.device_count() > 1):
             args = ["--multi_gpu"] + args
+        elif torch.xpu.is_available() and (torch.xpu.device_count() > 1):
+            args = ["--multi_gpu"] + args
         args = self.parser.parse_args(["--monitor_interval", "0.1", str(self.test_file_path)])
         launch_command(args)
 
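With this change the launcher test requests the `--multi_gpu` code path whenever more than one CUDA or XPU device is visible; `accelerate launch --multi_gpu` is not CUDA-specific. A hedged sketch of the same argument assembly outside the test class (the file name is a placeholder; in the test the resulting list is handed to the launch-command parser):

import torch


def build_launch_args(test_file_path: str) -> list:
    args = ["--monitor_interval", "0.1", test_file_path]
    # Opt into the multi-process path when more than one CUDA or XPU device is present.
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        args = ["--multi_gpu"] + args
    elif hasattr(torch, "xpu") and torch.xpu.is_available() and torch.xpu.device_count() > 1:
        args = ["--multi_gpu"] + args
    return args


# e.g. build_launch_args("test_script.py")
# -> ["--multi_gpu", "--monitor_interval", "0.1", "test_script.py"] on a 2-device machine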
@@ -28,10 +28,10 @@ from accelerate.test_utils.testing import (
     TempDirTestCase,
     get_launch_command,
     is_hpu_available,
+    is_xpu_available,
     require_fp16,
     require_huggingface_suite,
     require_multi_device,
-    require_non_xpu,
     require_pippy,
     require_schedulefree,
     require_trackers,
@@ -204,6 +204,8 @@ class FeatureExamplesTests(TempDirTestCase):
             num_processes = torch.hpu.device_count()
         elif torch.cuda.is_available():
             num_processes = torch.cuda.device_count()
+        elif is_xpu_available():
+            num_processes = torch.xpu.device_count()
         else:
             num_processes = 1
 
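The example tests now size `num_processes` from whichever backend reports devices. When driving `accelerate launch` by hand, the same count can be passed through the `--num_processes` flag; a short sketch, assuming `my_script.py` stands in for your own training script:

import subprocess

import torch

from accelerate.utils import is_xpu_available

if torch.cuda.is_available():
    num_processes = torch.cuda.device_count()
elif is_xpu_available():
    num_processes = torch.xpu.device_count()
else:
    num_processes = 1

# One process per visible accelerator; the script path is illustrative.
subprocess.run(["accelerate", "launch", f"--num_processes={num_processes}", "my_script.py"], check=True)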
@@ -291,14 +293,12 @@ class FeatureExamplesTests(TempDirTestCase):
         run_command(self.launch_args + testargs)
 
     @require_pippy
-    @require_non_xpu
     @require_multi_device
     def test_pippy_examples_bert(self):
         testargs = ["examples/inference/pippy/bert.py"]
         run_command(self.launch_args + testargs)
 
     @require_pippy
-    @require_non_xpu
     @require_multi_device
     def test_pippy_examples_gpt2(self):
         testargs = ["examples/inference/pippy/gpt2.py"]