Mirror of https://github.com/huggingface/peft.git, synced 2025-10-20 15:33:48 +08:00
ENH Support XPU in train_memory.py script (#2729)
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
@@ -70,17 +70,18 @@ from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
 # suppress all warnings
 warnings.filterwarnings("ignore")
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
 dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}
 
 
-def init_cuda():
+def init_accelerator():
     torch.manual_seed(0)
     if device == "cpu":
         return
 
-    torch.cuda.reset_peak_memory_stats()
-    torch.cuda.manual_seed_all(0)
+    device_module = getattr(torch, device, torch.cuda)
+    device_module.reset_peak_memory_stats()
+    device_module.manual_seed_all(0)
     # might not be necessary, but just to be sure
     nn.Linear(1, 1).to(device)
 
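For context on the hunk above: it swaps the hard-coded CUDA check for PyTorch's accelerator API and funnels device-specific calls through a single module handle. A minimal standalone sketch of the pattern, assuming a PyTorch recent enough to ship torch.accelerator; the explicit None guard for CPU-only builds is an addition here, not part of the diff:

import torch

# Pick the accelerator type ("cuda", "xpu", ...) when the torch.accelerator
# API exists; otherwise fall back to the old CUDA-or-CPU check.
if hasattr(torch, "accelerator"):
    acc = torch.accelerator.current_accelerator()
    device = acc.type if acc is not None else "cpu"  # None on CPU-only builds
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

# torch.cuda and torch.xpu expose matching memory/seeding helpers
# (reset_peak_memory_stats, max_memory_allocated, empty_cache, ...), so one
# handle chosen by device type serves both back ends.
device_module = getattr(torch, device, torch.cuda)
print(device, device_module.__name__)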
@@ -106,9 +107,10 @@ def get_data(tokenizer):
 
 
 def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
-    init_cuda()
-    cuda_memory_init = torch.cuda.max_memory_allocated()
-    cuda_memory_log = []
+    init_accelerator()
+    device_module = getattr(torch, device, torch.cuda)
+    accelerator_memory_init = device_module.max_memory_allocated()
+    accelerator_memory_log = []
 
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.model_max_length = max_seq_length
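The init/measure bookkeeping from this hunk can be exercised in isolation. A hypothetical helper sketching it (the name measure_peak_mb and the CPU early return are illustrative assumptions, not from the script):

import torch

def measure_peak_mb(fn, device: str) -> float:
    """Run fn once and return the peak accelerator memory it used, in MiB.

    Illustrative sketch only: CPU has no per-device allocator stats, so
    0.0 is returned there; otherwise the device_module pattern from the
    diff picks torch.cuda or torch.xpu by device type.
    """
    if device == "cpu":
        fn()
        return 0.0
    device_module = getattr(torch, device, torch.cuda)
    device_module.reset_peak_memory_stats()
    baseline = device_module.max_memory_allocated()
    fn()
    return (device_module.max_memory_allocated() - baseline) / 2**20

For example, measure_peak_mb(lambda: torch.zeros(4096, 4096, device=device), device) should report roughly 64 MiB for the float32 tensor (4096 * 4096 * 4 bytes).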
@@ -177,8 +179,8 @@ def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
         loss.backward()
         optimizer.step()
         losses.append(loss.item())
-        cuda_memory_log.append(torch.cuda.memory_allocated() - cuda_memory_init)
-        torch.cuda.empty_cache()
+        accelerator_memory_log.append(device_module.memory_allocated() - accelerator_memory_init)
+        device_module.empty_cache()
         gc.collect()
         toc = time.perf_counter()
         print(f"step {i:3d} loss {loss.item():.6f} time {toc - tic:.2f}s", file=sys.stderr)
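Trimmed to its essentials, the per-step pattern in this hunk (log allocation against the baseline, then flush the caching allocator) looks like the following; the toy linear model and step count are stand-ins, not from the script:

import gc
import time

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"  # simplified stand-in
# torch.cuda's memory queries return 0 and empty_cache is a no-op when CUDA
# is unavailable, so this sketch also runs on a CPU-only machine.
device_module = torch.cuda if device == "cpu" else getattr(torch, device, torch.cuda)

model = nn.Linear(256, 256).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accelerator_memory_init = device_module.max_memory_allocated()
accelerator_memory_log = []

for i in range(3):
    tic = time.perf_counter()
    loss = model(torch.randn(32, 256, device=device)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Log allocation relative to the post-init baseline, then release cached
    # blocks so later readings are not inflated by the caching allocator.
    accelerator_memory_log.append(device_module.memory_allocated() - accelerator_memory_init)
    device_module.empty_cache()
    gc.collect()
    toc = time.perf_counter()
    print(f"step {i:3d} loss {loss.item():.6f} time {toc - tic:.2f}s")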
@@ -191,10 +193,10 @@ def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
 
     toc_total = time.perf_counter()
 
-    cuda_memory_final = torch.cuda.max_memory_allocated()
-    cuda_memory_avg = int(sum(cuda_memory_log) / len(cuda_memory_log))
-    print(f"cuda memory avg: {cuda_memory_avg // 2**20}MB")
-    print(f"cuda memory max: {(cuda_memory_final - cuda_memory_init) // 2**20}MB")
+    accelerator_memory_final = device_module.max_memory_allocated()
+    accelerator_memory_avg = int(sum(accelerator_memory_log) / len(accelerator_memory_log))
+    print(f"{model.device.type} memory avg: {accelerator_memory_avg // 2**20}MB")
+    print(f"{model.device.type} memory max: {(accelerator_memory_final - accelerator_memory_init) // 2**20}MB")
     print(f"total time: {toc_total - tic_total:.2f}s")
 
     with tempfile.TemporaryDirectory() as tmp_dir:
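A small arithmetic check of the reporting above: the byte counts are floor-divided by 2**20, so the printed "MB" figures are strictly MiB. With fabricated example values:

# Fabricated byte counts, purely to illustrate the reporting arithmetic.
accelerator_memory_init = 512 * 2**20
accelerator_memory_final = 2560 * 2**20
accelerator_memory_log = [900 * 2**20, 1100 * 2**20, 1000 * 2**20]

accelerator_memory_avg = int(sum(accelerator_memory_log) / len(accelerator_memory_log))
print(f"memory avg: {accelerator_memory_avg // 2**20}MB")  # memory avg: 1000MB
print(f"memory max: {(accelerator_memory_final - accelerator_memory_init) // 2**20}MB")  # memory max: 2048MB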