ENH Support XPU in train_memory.py script (#2729)

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Author: Yao Matrix
Date: 2025-08-08 03:06:46 -07:00
Committed by: GitHub
Parent: e98a59ec2d
Commit: a4b41e7924


@@ -70,17 +70,18 @@ from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
 # suppress all warnings
 warnings.filterwarnings("ignore")

-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
 dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}


-def init_cuda():
+def init_accelerator():
     torch.manual_seed(0)
     if device == "cpu":
         return

-    torch.cuda.reset_peak_memory_stats()
-    torch.cuda.manual_seed_all(0)
+    device_module = getattr(torch, device, torch.cuda)
+    device_module.reset_peak_memory_stats()
+    device_module.manual_seed_all(0)

     # might not be necessary, but just to be sure
     nn.Linear(1, 1).to(device)
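This dispatch works because torch.cuda and torch.xpu expose the same seeding and memory-stats API, so the right per-device module can be looked up by name at runtime. A minimal sketch of the pattern, assuming a recent PyTorch build that ships torch.accelerator and has an accelerator attached (older builds fall through to "cuda"):

import torch

# Resolve the active accelerator type ("cuda", "xpu", ...); fall back to
# "cuda" on PyTorch builds that predate torch.accelerator.
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"

if device != "cpu":
    # torch.cuda and torch.xpu share this API, so lookup by name is enough;
    # torch.cuda stays the fallback for unknown types.
    device_module = getattr(torch, device, torch.cuda)
    device_module.manual_seed_all(0)         # seed every device of that type
    device_module.reset_peak_memory_stats()  # zero the allocation high-water mark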
@@ -106,9 +107,10 @@ def get_data(tokenizer):


 def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
-    init_cuda()
-    cuda_memory_init = torch.cuda.max_memory_allocated()
-    cuda_memory_log = []
+    init_accelerator()
+    device_module = getattr(torch, device, torch.cuda)
+    accelerator_memory_init = device_module.max_memory_allocated()
+    accelerator_memory_log = []

     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.model_max_length = max_seq_length
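Taking max_memory_allocated() before anything is loaded establishes a baseline, so every later figure can be reported as a delta that excludes allocations made outside train(). A small illustration of the idea, assuming device and device_module are resolved as above with an accelerator attached; the Linear layer is a hypothetical stand-in for the real model:

baseline = device_module.max_memory_allocated()         # bytes in use before loading
model = torch.nn.Linear(4096, 4096).to(device)          # hypothetical workload
used = device_module.max_memory_allocated() - baseline  # growth attributable to it
print(f"weights + overhead: {used // 2**20}MB")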
@@ -177,8 +179,8 @@ def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
         loss.backward()
         optimizer.step()
         losses.append(loss.item())
-        cuda_memory_log.append(torch.cuda.memory_allocated() - cuda_memory_init)
-        torch.cuda.empty_cache()
+        accelerator_memory_log.append(device_module.memory_allocated() - accelerator_memory_init)
+        device_module.empty_cache()
         gc.collect()
         toc = time.perf_counter()
         print(f"step {i:3d} loss {loss.item():.6f} time {toc - tic:.2f}s", file=sys.stderr)
@@ -191,10 +193,10 @@ def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
     toc_total = time.perf_counter()
-    cuda_memory_final = torch.cuda.max_memory_allocated()
-    cuda_memory_avg = int(sum(cuda_memory_log) / len(cuda_memory_log))
-    print(f"cuda memory avg: {cuda_memory_avg // 2**20}MB")
-    print(f"cuda memory max: {(cuda_memory_final - cuda_memory_init) // 2**20}MB")
+    accelerator_memory_final = device_module.max_memory_allocated()
+    accelerator_memory_avg = int(sum(accelerator_memory_log) / len(accelerator_memory_log))
+    print(f"{model.device.type} memory avg: {accelerator_memory_avg // 2**20}MB")
+    print(f"{model.device.type} memory max: {(accelerator_memory_final - accelerator_memory_init) // 2**20}MB")
     print(f"total time: {toc_total - tic_total:.2f}s")

     with tempfile.TemporaryDirectory() as tmp_dir:
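The summary averages the per-step deltas and also prints the overall peak relative to the baseline; // 2**20 converts bytes to MiB (printed with an MB label). The same arithmetic on hypothetical log values:

accelerator_memory_log = [3 * 2**30, 4 * 2**30, 5 * 2**30]  # hypothetical byte counts
accelerator_memory_avg = int(sum(accelerator_memory_log) / len(accelerator_memory_log))
print(f"memory avg: {accelerator_memory_avg // 2**20}MB")   # memory avg: 4096MB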