mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[XPU] call empty_cache for dynamo tests (#126377)
When running a batch of models, the missing `empty_cache()` call would cause OOM for subsequent models. This PR unifies the `empty_cache` call for both CUDA and XPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126377
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire
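A condensed usage sketch (illustrative, not part of the PR): it assumes the new `empty_gpu_cache` helper from benchmarks/dynamo/common.py (shown in the diff below) is importable, and it uses placeholder model names and a stub runner.

import torch

from common import empty_gpu_cache  # helper added by this PR (benchmarks/dynamo/common.py)


def load_and_run_model(name: str, device: str) -> None:
    # Placeholder for the harness's per-model load/compile/measure step.
    pass


# Hypothetical batch of models; running them back to back without freeing the
# caching allocator between runs is what led to the OOMs this PR addresses.
models = ["resnet50", "hf_Bert", "timm_vision_transformer"]
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

for model_name in models:
    load_and_run_model(model_name, device)
    empty_gpu_cache(device)  # release cached allocator blocks before the next model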
committed by PyTorch MergeBot
parent 9edf54df4d
commit 5756b53dd8
@@ -354,6 +354,24 @@ def patch_torch_manual_seed():
     torch.manual_seed = deterministic_torch_manual_seed
 
 
+def empty_gpu_cache(device):
+    """
+    Explicitly empty gpu cache to avoid OOM in subsequent run.
+    """
+
+    if device not in ["cuda", "xpu"]:
+        log.warning(
+            "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]",
+            device,
+        )
+        return
+
+    if device == "cuda":
+        torch.cuda.empty_cache()
+    elif device == "xpu":
+        torch.xpu.empty_cache()
+
+
 def synchronize():
     pass
 
@@ -2278,7 +2296,7 @@ class BenchmarkRunner:
     def batch_size_finder(self, device, model_name, initial_batch_size=1024):
         batch_size = initial_batch_size
         while batch_size >= 1:
-            torch.cuda.empty_cache()
+            empty_gpu_cache(current_device)
             try:
                 device, name, model, example_inputs, _ = self.load_model(
                     device,
@@ -2468,7 +2486,7 @@ class BenchmarkRunner:
             fp64_outputs = None
         finally:
             del model_fp64, inputs_fp64
-            torch.cuda.empty_cache()
+            empty_gpu_cache(current_device)
 
         tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
             self.args.training, current_device, name
@@ -2497,7 +2515,7 @@ class BenchmarkRunner:
                 return record_status(accuracy_status, dynamo_start_stats=start_stats)
             finally:
                 del model_copy
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
 
             # Rerun native pytorch
             reset_rng_state()
@@ -2518,7 +2536,7 @@ class BenchmarkRunner:
                 return record_status(accuracy_status, dynamo_start_stats=start_stats)
             finally:
                 del model_copy
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
 
             # Two eager runs should have exactly same result
             is_same = True
@@ -2719,7 +2737,7 @@ class BenchmarkRunner:
             try:
                 if current_device == "cuda":
                     torch.cuda.reset_peak_memory_stats()
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
                 t0 = time.perf_counter()
                 for _ in range(niters):
                     fn(model, example_inputs)
@@ -2949,7 +2967,7 @@ class BenchmarkRunner:
                 name, model, example_inputs, optimize_ctx, experiment, tag
             )
             print(status)
-        torch.cuda.empty_cache()
+        empty_gpu_cache(current_device)
 
         self.maybe_preserve_compile_debug(name, status)
 