[XPU] call empty_cache for dynamo tests (#126377)
When running a batch of models, failing to call `empty_cache()` between runs leads to OOM in subsequent models. This PR unifies the `empty_cache` call for both CUDA and XPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126377
Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire
commit 5756b53dd8 (parent 9edf54df4d)
committed by PyTorch MergeBot
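For context, a minimal sketch of the calling pattern this change enables, assuming the `empty_gpu_cache` helper defined in the diff below is in scope; `run_one_model` and the model list are hypothetical stand-ins for the benchmark harness, not names from this PR:

    # Hypothetical harness loop; empty_gpu_cache is the helper the diff
    # below adds to benchmarks/dynamo/common.py.
    current_device = "xpu"  # or "cuda"; any other value only logs a warning

    def run_one_model(name):
        print(f"benchmarking {name}")  # stand-in for BenchmarkRunner's per-model work

    for name in ["model_a", "model_b"]:
        # Free cached allocator blocks so the previous model's cache
        # cannot push the next model into OOM.
        empty_gpu_cache(current_device)
        run_one_model(name)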
@@ -354,6 +354,24 @@ def patch_torch_manual_seed():
     torch.manual_seed = deterministic_torch_manual_seed
 
 
+def empty_gpu_cache(device):
+    """
+    Explicitly empty gpu cache to avoid OOM in subsequent run.
+    """
+
+    if device not in ["cuda", "xpu"]:
+        log.warning(
+            "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]",
+            device,
+        )
+        return
+
+    if device == "cuda":
+        torch.cuda.empty_cache()
+    elif device == "xpu":
+        torch.xpu.empty_cache()
+
+
 def synchronize():
     pass
 
@@ -2278,7 +2296,7 @@ class BenchmarkRunner:
     def batch_size_finder(self, device, model_name, initial_batch_size=1024):
         batch_size = initial_batch_size
         while batch_size >= 1:
-            torch.cuda.empty_cache()
+            empty_gpu_cache(current_device)
             try:
                 device, name, model, example_inputs, _ = self.load_model(
                     device,
@@ -2468,7 +2486,7 @@ class BenchmarkRunner:
                 fp64_outputs = None
             finally:
                 del model_fp64, inputs_fp64
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
 
             tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
                 self.args.training, current_device, name
@@ -2497,7 +2515,7 @@ class BenchmarkRunner:
                     return record_status(accuracy_status, dynamo_start_stats=start_stats)
                 finally:
                     del model_copy
-                    torch.cuda.empty_cache()
+                    empty_gpu_cache(current_device)
 
                 # Rerun native pytorch
                 reset_rng_state()
@@ -2518,7 +2536,7 @@ class BenchmarkRunner:
                     return record_status(accuracy_status, dynamo_start_stats=start_stats)
                 finally:
                     del model_copy
-                    torch.cuda.empty_cache()
+                    empty_gpu_cache(current_device)
 
                 # Two eager runs should have exactly same result
                 is_same = True
@@ -2719,7 +2737,7 @@ class BenchmarkRunner:
             try:
                 if current_device == "cuda":
                     torch.cuda.reset_peak_memory_stats()
-                torch.cuda.empty_cache()
+                empty_gpu_cache(current_device)
                 t0 = time.perf_counter()
                 for _ in range(niters):
                     fn(model, example_inputs)
@@ -2949,7 +2967,7 @@ class BenchmarkRunner:
                 name, model, example_inputs, optimize_ctx, experiment, tag
             )
             print(status)
-            torch.cuda.empty_cache()
+            empty_gpu_cache(current_device)
 
             self.maybe_preserve_compile_debug(name, status)
 
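As a design note, the same two-way dispatch could also be written generically over the backend module; a hedged sketch (not part of this PR), assuming each device's backend module exposes an `empty_cache()`:

    import torch

    def empty_gpu_cache_generic(device: str) -> None:
        # Hypothetical alternative: resolve torch.cuda / torch.xpu by name
        # and call empty_cache() when the backend provides it.
        backend = getattr(torch, device, None)
        if backend is not None and hasattr(backend, "empty_cache"):
            backend.empty_cache()

The PR spells out both branches explicitly instead, which preserves the warning path for unexpected device strings.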