From 40ebb4bea35a1e0de22e68c29f5c446f86e58530 Mon Sep 17 00:00:00 2001
From: Yao Matrix
Date: Mon, 8 Sep 2025 03:16:56 -0700
Subject: [PATCH] make torch_native_parallelism examples device agnostic
 (#3759)

* make torch_native_parallelism examples device agnostic

Signed-off-by: YAO Matrix

* xxx

Signed-off-by: YAO Matrix

* xxx

Signed-off-by: YAO Matrix

* Style + deprecation warning

---------

Signed-off-by: YAO Matrix
Co-authored-by: S1ro1
---
 .../torch_native_parallelism/nd_parallel_trainer.py |  8 ++++++--
 examples/torch_native_parallelism/utils.py          | 12 +++++++-----
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/examples/torch_native_parallelism/nd_parallel_trainer.py b/examples/torch_native_parallelism/nd_parallel_trainer.py
index 7d158515..389584a4 100644
--- a/examples/torch_native_parallelism/nd_parallel_trainer.py
+++ b/examples/torch_native_parallelism/nd_parallel_trainer.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import argparse
 
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
 
 from accelerate.utils import ParallelismConfig
@@ -28,7 +29,7 @@ def parse_args():
     parser.add_argument("--checkpoint-frequency", type=int, default=100)
     parser.add_argument("--model-name", type=str, default=MODEL_ID)
     parser.add_argument("--save-dir", type=str, default=f"./accelerate-nd-parallel-{MODEL_ID.split('/')[-1]}")
-    parser.add_argument("--device-type", type=str, default="cuda")
+    parser.add_argument("--device-type", type=str, default="auto")
 
     return parser.parse_args()
 
@@ -38,6 +39,9 @@ def main():
     pc = ParallelismConfig()
     args = parse_args()
 
+    if args.device_type == "auto":
+        args.device_type = torch.accelerator.current_accelerator().type
+
     model_kwargs = {}
     if pc.tp_enabled:
         model_kwargs["tp_plan"] = "auto"
@@ -66,7 +70,7 @@ def main():
     trainer = Trainer(
         model=model,
         args=training_args,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         train_dataset=packed_dataset,
     )
 
diff --git a/examples/torch_native_parallelism/utils.py b/examples/torch_native_parallelism/utils.py
index 3dbe583a..ff55864f 100644
--- a/examples/torch_native_parallelism/utils.py
+++ b/examples/torch_native_parallelism/utils.py
@@ -202,16 +202,18 @@ def setup_tokenizer(model_id: str) -> AutoTokenizer:
 
 
 def gpu_memory_usage_all(device=0):
-    device = torch.device(f"cuda:{device}")
+    device_type = torch.accelerator.current_accelerator().type
+    device = torch.device(f"{device_type}:{device}")
+    torch_device_module = getattr(torch, device_type, torch.cuda)
     _BYTES_IN_GIB = 1024**3
-    peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
-    peak_memory_alloc = torch.cuda.max_memory_allocated(device) / _BYTES_IN_GIB
-    peak_memory_reserved = torch.cuda.max_memory_reserved(device) / _BYTES_IN_GIB
+    peak_memory_active = torch_device_module.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
+    peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB
+    peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB
     memory_stats = {
         "peak_memory_active": peak_memory_active,
         "peak_memory_alloc": peak_memory_alloc,
         "peak_memory_reserved": peak_memory_reserved,
     }
-    torch.cuda.reset_peak_memory_stats(device)
+    torch_device_module.reset_peak_memory_stats(device)
     return memory_stats
 
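
Note (not part of the patch): the snippet below is a minimal standalone sketch of the device-agnostic pattern the diff relies on. It assumes a PyTorch version where torch.accelerator is available (2.6+); the None check and the "cuda" fallback are illustrative safeguards added here, not code from the diff.

import torch

# Resolve the accelerator type the same way the patch does; the None check
# and "cuda" fallback are illustrative, the patch calls .type directly.
accelerator = torch.accelerator.current_accelerator()
device_type = accelerator.type if accelerator is not None else "cuda"

# Look up the per-backend memory module (torch.cuda, torch.xpu, ...) the way
# utils.py now does, so the memory helper also works on non-CUDA accelerators.
torch_device_module = getattr(torch, device_type, torch.cuda)

print(f"device type: {device_type}")
print(f"peak active bytes: {torch_device_module.memory_stats().get('active_bytes.all.peak', 0)}")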