[XPU] call empty_cache for dynamo tests (#126377)

When running a batch of models, lacking `empty_cache()` would result in OOM for subsequent models.

This PR unifies the `empty_cache` call for both CUDA and XPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126377
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire
This commit is contained in:
Stonepia
2024-05-17 06:05:51 +00:00
committed by PyTorch MergeBot
parent 9edf54df4d
commit 5756b53dd8

View File

@ -354,6 +354,24 @@ def patch_torch_manual_seed():
torch.manual_seed = deterministic_torch_manual_seed
def empty_gpu_cache(device):
"""
Explicitly empty gpu cache to avoid OOM in subsequent run.
"""
if device not in ["cuda", "xpu"]:
log.warning(
"Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]",
device,
)
return
if device == "cuda":
torch.cuda.empty_cache()
elif device == "xpu":
torch.xpu.empty_cache()
def synchronize():
pass
@ -2278,7 +2296,7 @@ class BenchmarkRunner:
def batch_size_finder(self, device, model_name, initial_batch_size=1024):
batch_size = initial_batch_size
while batch_size >= 1:
torch.cuda.empty_cache()
empty_gpu_cache(current_device)
try:
device, name, model, example_inputs, _ = self.load_model(
device,
@ -2468,7 +2486,7 @@ class BenchmarkRunner:
fp64_outputs = None
finally:
del model_fp64, inputs_fp64
torch.cuda.empty_cache()
empty_gpu_cache(current_device)
tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
self.args.training, current_device, name
@ -2497,7 +2515,7 @@ class BenchmarkRunner:
return record_status(accuracy_status, dynamo_start_stats=start_stats)
finally:
del model_copy
torch.cuda.empty_cache()
empty_gpu_cache(current_device)
# Rerun native pytorch
reset_rng_state()
@ -2518,7 +2536,7 @@ class BenchmarkRunner:
return record_status(accuracy_status, dynamo_start_stats=start_stats)
finally:
del model_copy
torch.cuda.empty_cache()
empty_gpu_cache(current_device)
# Two eager runs should have exactly same result
is_same = True
@ -2719,7 +2737,7 @@ class BenchmarkRunner:
try:
if current_device == "cuda":
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
empty_gpu_cache(current_device)
t0 = time.perf_counter()
for _ in range(niters):
fn(model, example_inputs)
@ -2949,7 +2967,7 @@ class BenchmarkRunner:
name, model, example_inputs, optimize_ctx, experiment, tag
)
print(status)
torch.cuda.empty_cache()
empty_gpu_cache(current_device)
self.maybe_preserve_compile_debug(name, status)