make torch_native_parallelism examples device agnostic (#3759)

* make torch_native_parallelism examples device agnostic

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* xxx

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* xxx

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Style + deprecation warning

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
Author: Yao Matrix
Date: 2025-09-08 03:16:56 -07:00
Committed by: GitHub
Parent: ec92b1af7a
Commit: 40ebb4bea3
2 changed files with 13 additions and 7 deletions


@@ -13,6 +13,7 @@
 # limitations under the License.
 import argparse
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
 from accelerate.utils import ParallelismConfig
@@ -28,7 +29,7 @@ def parse_args():
     parser.add_argument("--checkpoint-frequency", type=int, default=100)
     parser.add_argument("--model-name", type=str, default=MODEL_ID)
     parser.add_argument("--save-dir", type=str, default=f"./accelerate-nd-parallel-{MODEL_ID.split('/')[-1]}")
-    parser.add_argument("--device-type", type=str, default="cuda")
+    parser.add_argument("--device-type", type=str, default="auto")
     return parser.parse_args()
@@ -38,6 +39,9 @@ def main():
     pc = ParallelismConfig()
     args = parse_args()
+    if args.device_type == "auto":
+        args.device_type = torch.accelerator.current_accelerator().type
+
     model_kwargs = {}
     if pc.tp_enabled:
         model_kwargs["tp_plan"] = "auto"
@@ -66,7 +70,7 @@ def main():
     trainer = Trainer(
         model=model,
         args=training_args,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
         train_dataset=packed_dataset,
     )
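
For context: with `--device-type auto` (the new default), the example resolves the accelerator at runtime through PyTorch's `torch.accelerator` API (available in recent PyTorch releases), so the same script runs on CUDA, XPU, and other backends; the Trainer change simply moves from the deprecated `tokenizer` argument to `processing_class`. A minimal sketch of the resolution pattern, where the CPU fallback is an assumption for illustration and not part of this commit:

    import torch

    def resolve_device_type(requested: str = "auto") -> str:
        # Non-"auto" values (e.g. "cuda", "xpu") are passed through unchanged.
        if requested != "auto":
            return requested
        accelerator = torch.accelerator.current_accelerator()
        # Fall back to "cpu" when no accelerator is detected (assumption, not in the diff).
        return accelerator.type if accelerator is not None else "cpu"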


@@ -202,16 +202,18 @@ def setup_tokenizer(model_id: str) -> AutoTokenizer:
 def gpu_memory_usage_all(device=0):
-    device = torch.device(f"cuda:{device}")
+    device_type = torch.accelerator.current_accelerator().type
+    device = torch.device(f"{device_type}:{device}")
+    torch_device_module = getattr(torch, device_type, torch.cuda)
     _BYTES_IN_GIB = 1024**3
-    peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
-    peak_memory_alloc = torch.cuda.max_memory_allocated(device) / _BYTES_IN_GIB
-    peak_memory_reserved = torch.cuda.max_memory_reserved(device) / _BYTES_IN_GIB
+    peak_memory_active = torch_device_module.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
+    peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB
+    peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB
     memory_stats = {
         "peak_memory_active": peak_memory_active,
         "peak_memory_alloc": peak_memory_alloc,
         "peak_memory_reserved": peak_memory_reserved,
     }
-    torch.cuda.reset_peak_memory_stats(device)
+    torch_device_module.reset_peak_memory_stats(device)
     return memory_stats
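
The utilities file applies the same idea to memory accounting: `getattr(torch, device_type, torch.cuda)` picks the backend module (`torch.cuda`, `torch.xpu`, ...) by name, and the peak-memory calls go through that module instead of hard-coded `torch.cuda`. A minimal usage sketch under that assumption; `report_peak_memory` is a hypothetical helper, not part of the commit:

    import torch

    def report_peak_memory(device_index: int = 0) -> None:
        device_type = torch.accelerator.current_accelerator().type
        # torch.cuda and torch.xpu expose the same peak-memory API surface.
        device_module = getattr(torch, device_type, torch.cuda)
        peak_gib = device_module.max_memory_allocated(device_index) / 1024**3
        print(f"[{device_type}:{device_index}] peak allocated: {peak_gib:.2f} GiB")
        device_module.reset_peak_memory_stats(device_index)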