Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 10:03:46 +08:00)
enable fsdp2 benchmark on XPU (#3590)
* enable fsdp2 benchmark on XPU

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* add deterministic

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

---------

Signed-off-by: Matrix YAO <matrix.yao@intel.com>
@@ -92,6 +92,7 @@ def main():
     ]
 
     results = {}
+    torch.use_deterministic_algorithms(True)
 
     for evaluation, label in zip(evaluations, labels):
         results[label] = evaluation(args, CONFIG)
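A quick sketch of what the added call buys the benchmark (not part of the diff): with deterministic algorithms enabled, PyTorch raises an error on any op that would otherwise pick a nondeterministic kernel, so repeated runs are numerically comparable across backends. The CUBLAS_WORKSPACE_CONFIG line below is an extra requirement on CUDA only, added here for completeness:

import os
import torch

# Required by cuBLAS when deterministic mode is on (CUDA only); must be set
# before the first matmul. This env var is not part of the diff above.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

torch.manual_seed(0)
torch.use_deterministic_algorithms(True)

x = torch.randn(128, 128)
assert torch.equal(x @ x, x @ x)  # same op, same inputs, bit-identical results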
@@ -37,7 +37,7 @@ class MemoryTracker:
 
     Args:
         device (`torch.device`):
-            Cuda device to monitor.
+            PyTorch device to monitor.
         output_directory (`str`):
             Directory to save the memory usage data to, will be created if it doesn't exist.
         run_name (`str`):
@@ -63,14 +63,15 @@ class MemoryTracker:
         self._thread = None
         self._state = PartialState()
         self._process = psutil.Process()
-        self._devicee = device
+        self._device = device
+        self.torch_accelerator_module = getattr(torch, device.type, torch.cuda)
 
     def _monitor(self):
         self.start_time = time.time()
 
         while self.running:
-            allocated = torch.cuda.memory_allocated(self._devicee) / (1024 * 1024)
-            reserved = torch.cuda.memory_reserved(self._devicee) / (1024 * 1024)
+            allocated = self.torch_accelerator_module.memory_allocated(self._device) / (1024 * 1024)
+            reserved = self.torch_accelerator_module.memory_reserved(self._device) / (1024 * 1024)
             virtual_memory = self._process.memory_info().rss / (1024 * 1024)
 
             self.allocated_memory.append(allocated)
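The getattr(torch, device.type, torch.cuda) line is the heart of the XPU enablement: torch.cuda and torch.xpu expose the same memory-introspection surface (memory_allocated, memory_reserved, empty_cache), so the tracker resolves the backend module once from the device type instead of hardcoding CUDA. A standalone sketch of the pattern (the helper name is mine):

import torch

def accelerator_module(device: torch.device):
    # "cuda" -> torch.cuda, "xpu" -> torch.xpu; device types with no matching
    # torch submodule fall back to torch.cuda, exactly as in the diff above.
    return getattr(torch, device.type, torch.cuda)

device = torch.device("xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda")
mod = accelerator_module(device)
if mod.is_available():
    print(f"{mod.memory_allocated(device) / (1024 * 1024):.1f} MiB allocated on {device}")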
@@ -82,13 +83,13 @@ class MemoryTracker:
 
     def start(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        self.torch_accelerator_module.empty_cache()
 
         if self.output_directory:
             os.makedirs(self.output_directory, exist_ok=True)
 
         if self.save_memory_snapshot:
-            torch.cuda.memory._record_memory_history()
+            self.torch_accelerator_module.memory._record_memory_history()
 
         self.running = True
         self._thread = threading.Thread(target=self._monitor)
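For context on the unchanged lines around these hunks: between start() and stop() the tracker samples memory from a background thread guarded by a self.running flag. A self-contained, hypothetical miniature of that loop, sampling only process RSS (all names here are mine, not the benchmark's):

import threading
import time

import psutil

class RssSampler:
    # Minimal version of the tracker's polling loop: a daemon thread samples
    # this process's resident set size until stop() flips the flag.
    def __init__(self, interval=0.05):
        self.samples, self.running, self.interval = [], False, interval
        self._process = psutil.Process()
        self._thread = None

    def _monitor(self):
        while self.running:
            self.samples.append(self._process.memory_info().rss / (1024 * 1024))
            time.sleep(self.interval)

    def start(self):
        self.running = True
        self._thread = threading.Thread(target=self._monitor, daemon=True)
        self._thread.start()

    def stop(self):
        self.running = False
        self._thread.join()

sampler = RssSampler()
sampler.start()
buf = bytearray(256 * 1024 * 1024)  # hold ~256 MiB while sampling
time.sleep(0.5)
sampler.stop()
print(f"peak RSS while sampling: {max(sampler.samples):.1f} MiB")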
@@ -102,7 +103,7 @@ class MemoryTracker:
 
         if self.save_memory_snapshot and self._state.is_main_process and self.output_directory:
             output_file = os.path.join(self.output_directory, f"{self.run_name}_memory_snapshot.pkl")
-            torch.cuda.memory._dump_snapshot(output_file)
+            self.torch_accelerator_module.memory._dump_snapshot(output_file)
 
         if self._state.is_main_process and self.output_directory:
             path = os.path.join(self.output_directory, f"{self.run_name}_memory_usage.json")
@@ -116,9 +117,10 @@ class MemoryTracker:
                 },
                 f,
             )
 
-        torch.cuda.memory._record_memory_history(False)
-        torch.cuda.empty_cache()
+        if self.save_memory_snapshot:
+            self.torch_accelerator_module.memory._record_memory_history(False)
+        self.torch_accelerator_module.empty_cache()
 
     @property
     def peak_allocated_memory(self):
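_record_memory_history and _dump_snapshot are private PyTorch hooks; this hunk also fixes a latent bug by only stopping the recording when a snapshot was actually requested, and by going through the same backend module that started it. A CUDA-spelled sketch of the record/dump lifecycle (the file name is illustrative):

import torch

# The .pkl written by _dump_snapshot can be inspected at
# https://pytorch.org/memory_viz. These are private APIs and may change.
if torch.cuda.is_available():
    torch.cuda.memory._record_memory_history()       # start recording alloc/free events
    x = torch.randn(1024, 1024, device="cuda")
    del x
    torch.cuda.memory._dump_snapshot("example_snapshot.pkl")
    torch.cuda.memory._record_memory_history(False)  # stop recording
    torch.cuda.empty_cache()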
@@ -219,7 +219,7 @@ def prepare_torch(
     tokenizer = get_tokenizer(config["model_name"])
     train_dataloader = prepare_dataloader(tokenizer, args, accelerator)
 
-    memory_tracker = MemoryTracker(accelerator, args.output_dir, run_name, args.save_memory_snapshot)
+    memory_tracker = MemoryTracker(accelerator.device, args.output_dir, run_name, args.save_memory_snapshot)
     memory_tracker.start()
 
     model = get_model(config["model_name"])
@@ -279,7 +279,7 @@ def prepare_accelerate(
     tokenizer = get_tokenizer(config["model_name"])
     train_dataloader = prepare_dataloader(tokenizer, args, accelerator)
 
-    memory_tracker = MemoryTracker(accelerator, args.output_dir, "accelerate", args.save_memory_snapshot)
+    memory_tracker = MemoryTracker(accelerator.device, args.output_dir, "accelerate", args.save_memory_snapshot)
     memory_tracker.start()
 
     model = get_model(config["model_name"])
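Taken together, the two call-site fixes mean the tracker now receives a torch.device rather than the whole Accelerator object, matching the constructor seen above. A usage sketch based only on the signature visible in this diff; the import path and the stop() method are assumptions about the benchmark's layout:

from accelerate import Accelerator
from utils import MemoryTracker  # hypothetical local import from the benchmark

accelerator = Accelerator()
# args: device, output directory, run name, save_memory_snapshot
tracker = MemoryTracker(accelerator.device, "outputs", "fsdp2", True)
tracker.start()
# ... build the model and run the measured training steps ...
tracker.stop()  # assumed counterpart to start(); writes JSON/snapshot on the main process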