Enable dynamo-traced optimizer peak memory tests (#124543)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124543
Approved by: https://github.com/yf225, https://github.com/janeyx99

Committed by: PyTorch MergeBot
Parent: 5033d3ba6d
Commit: f0c6d6100b
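For orientation, a minimal sketch of the kind of comparison these tests make: run one optimizer step eagerly and one under torch.compile, and read the CUDA peak-memory counters. The helper name, parameter sizes, and optimizer choice below are illustrative assumptions, not the actual test code.

import torch

def peak_mem_of_step(optim_cls, compiled=False, **kwargs):
    # Illustrative helper only (assumed name): peak CUDA memory of a single
    # optimizer step over two 256x256 fp32 parameters.
    params = [torch.rand(256, 256, device="cuda", requires_grad=True) for _ in range(2)]
    for p in params:
        p.grad = torch.rand_like(p)
    opt = optim_cls(params, **kwargs)
    step = torch.compile(opt.step) if compiled else opt.step
    torch.cuda.reset_peak_memory_stats()
    step()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

if torch.cuda.is_available():
    st_max_mem = peak_mem_of_step(torch.optim.Adam, foreach=False)
    mt_max_mem = peak_mem_of_step(torch.optim.Adam, foreach=True, compiled=True)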
@@ -22,13 +22,12 @@ from torch.testing._internal.common_optimizers import (
     optim_db, optims, OptimizerErrorEnum, _get_optim_inputs_including_global_cliquey_kwargs, TensorTracker)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM, onlyNativeDeviceTypes)
-from torch.testing._internal.common_utils import markDynamoStrictTest, parametrize, run_tests, TestCase
+from torch.testing._internal.common_utils import markDynamoStrictTest, parametrize, run_tests, TestCase, TEST_WITH_TORCHDYNAMO
 from torch.testing._internal.common_cuda import _create_scaling_case
 from torch.testing._internal.common_dtype import floating_types_and

 FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}


 def rosenbrock(tensor):
     assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
     x, y = tensor
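The rosenbrock helper is cut off by the hunk; for reference, a standard two-parameter Rosenbrock function along these lines (the exact expression in the test file may differ slightly) looks like:

import torch

def rosenbrock(tensor):
    assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    # Classic Rosenbrock surface; optimizers are driven toward its minimum at (1, 1).
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2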
@@ -826,7 +825,9 @@ class TestOptimRenewed(TestCase):
         st_max_mem, mt_max_mem = max_mems
         intermediate_size = nparams * param.nelement() * param.element_size()
         nintermediates = 1  # we expect a budget of 1 intermediate most of the time
-        if kwargs.get('capturable') or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
+        # Check the param group directly to handle if the compiler set capturable
+        if optimizer.param_groups[0].get("capturable", False) or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
             # with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
             # with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
             # ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
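The switch from kwargs.get('capturable') to reading the param group matters because the flag can be turned on after construction (per the new comment, by the compiler). A small hedged illustration, with the flag flipped by hand to mimic that:

import torch

param = torch.zeros(2, requires_grad=True)
optimizer = torch.optim.Adam([param])   # constructed without capturable=True
kwargs = {}                             # what the test was configured with

# Mimic a later change to the param group (done manually here for illustration):
optimizer.param_groups[0]["capturable"] = True

assert not kwargs.get("capturable")                        # old check: misses it
assert optimizer.param_groups[0].get("capturable", False)  # new check: sees it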
@@ -834,12 +835,22 @@ class TestOptimRenewed(TestCase):
             if optim_cls.__name__ == "NAdam":
                 # with capturable in NAdam, we have 3 extra intermediates for the
                 # bias_correction, mus, and mu_nexts
-                nintermediates = 5
+                if TEST_WITH_TORCHDYNAMO:
+                    # With dynamo, the eager/FX backend appears to hold memory longer than
+                    # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
+                    nintermediates = 8
+                else:
+                    nintermediates = 5

             if optim_cls.__name__ == "RAdam":
                 # RAdam has four intermediates with capturable
                 # num, unrect_step_size, buffer, grouped_grads
-                nintermediates = 4
+                if TEST_WITH_TORCHDYNAMO:
+                    # With dynamo, the eager/FX backend appears to hold memory longer than
+                    # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
+                    nintermediates = 6
+                else:
+                    nintermediates = 4

         elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop"]:
             # NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
@@ -847,6 +858,11 @@ class TestOptimRenewed(TestCase):
             # RMSprop uses avg and grads
             nintermediates = 2

+        # Dynamo ST uses less mem than eager in the case of Adam/Adagrad/Nadam/RAdam
+        # which makes the foreach memory check fail
+        if TEST_WITH_TORCHDYNAMO:
+            st_max_mem += 6000
+
         expected_max_mem = st_max_mem + intermediate_size * nintermediates
         # hipcc currently can't generate efficient code for the small buffer optimization
         # code path (see Note [small buffer optimization] for details), thus we always
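To make the budget concrete, a worked example with assumed numbers (two 256x256 fp32 parameters and the default budget of one intermediate; the values are illustrative, not taken from a test run):

nparams, nelement, element_size = 2, 256 * 256, 4      # assumed parameter shapes (fp32)
intermediate_size = nparams * nelement * element_size  # 2 * 65536 * 4 = 524288 bytes
nintermediates = 1                                     # default budget from above
st_max_mem = 10_000_000                                # assumed single-tensor peak, in bytes
expected_max_mem = st_max_mem + intermediate_size * nintermediates
assert expected_max_mem == 10_524_288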
@@ -1103,13 +1103,6 @@ optim_db: List[OptimizerInfo] = [
                 "test_tensor_lr",
                 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo("See #116028"),
                 "TestOptimRenewed",
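For readers unfamiliar with optim_db, each removed block in this and the following hunks is a per-optimizer skip entry for test_peak_memory_foreach. A hedged sketch of such an entry in isolation, assuming the DecorateInfo and skipIfTorchDynamo helpers from torch.testing._internal and omitting the surrounding OptimizerInfo:

from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_utils import skipIfTorchDynamo

# The kind of entry being deleted: it skipped test_peak_memory_foreach on
# TestOptimRenewed whenever the suite runs under TorchDynamo.
removed_skip = DecorateInfo(
    skipIfTorchDynamo(
        "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
    ),
    "TestOptimRenewed",
    "test_peak_memory_foreach",
)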
@@ -1175,13 +1168,6 @@ optim_db: List[OptimizerInfo] = [
                 "test_tensor_lr",
                 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo("See #116028"),
                 "TestOptimRenewed",
@@ -1277,13 +1263,6 @@ optim_db: List[OptimizerInfo] = [
                 "TestOptimRenewed",
                 "test_set_default_dtype_works_with_foreach",
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "See https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo(
                     "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
@@ -1328,13 +1307,6 @@ optim_db: List[OptimizerInfo] = [
                 "TestOptimRenewed",
                 "test_set_default_dtype_works_with_foreach",
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "See https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo(
                     "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
@@ -1407,13 +1379,6 @@ optim_db: List[OptimizerInfo] = [
                 "TestOptimRenewed",
                 "test_set_default_dtype_works_with_foreach",
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "See https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo(
                     "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
@@ -1442,13 +1407,6 @@ optim_db: List[OptimizerInfo] = [
                 "test_tensor_lr",
                 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo(
                     "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1547,13 +1505,6 @@ optim_db: List[OptimizerInfo] = [
                 "TestOptimRenewed",
                 "test_set_default_dtype_works_with_foreach",
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "See https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo(
                     "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
@@ -1604,13 +1555,6 @@ optim_db: List[OptimizerInfo] = [
                 "test_tensor_lr",
                 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo(
                     "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1662,13 +1606,6 @@ optim_db: List[OptimizerInfo] = [
                 "test_tensor_lr",
                 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo("See #116028"),
                 "TestOptimRenewed",
@@ -1718,13 +1655,6 @@ optim_db: List[OptimizerInfo] = [
                 "test_tensor_lr",
                 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo("See #116028"),
                 "TestOptimRenewed",
@@ -1803,13 +1733,6 @@ optim_db: List[OptimizerInfo] = [
                 "test_tensor_lr",
                 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
-                ),
-                "TestOptimRenewed",
-                "test_peak_memory_foreach",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo(
                     "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"