Enable dynamo-traced optimizer peak memory tests (#124543)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124543
Approved by: https://github.com/yf225, https://github.com/janeyx99
Author: Michael Lazos
Date: 2024-05-06 18:56:18 -07:00
Committed by: PyTorch MergeBot
parent 5033d3ba6d
commit f0c6d6100b
2 changed files with 21 additions and 82 deletions


@@ -22,13 +22,12 @@ from torch.testing._internal.common_optimizers import (
optim_db, optims, OptimizerErrorEnum, _get_optim_inputs_including_global_cliquey_kwargs, TensorTracker)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, largeTensorTest, onlyCPU, onlyCUDA, skipMPS, TEST_WITH_ROCM, onlyNativeDeviceTypes)
from torch.testing._internal.common_utils import markDynamoStrictTest, parametrize, run_tests, TestCase
from torch.testing._internal.common_utils import markDynamoStrictTest, parametrize, run_tests, TestCase, TEST_WITH_TORCHDYNAMO
from torch.testing._internal.common_cuda import _create_scaling_case
from torch.testing._internal.common_dtype import floating_types_and
FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
def rosenbrock(tensor):
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
x, y = tensor
@@ -826,7 +825,9 @@ class TestOptimRenewed(TestCase):
st_max_mem, mt_max_mem = max_mems
intermediate_size = nparams * param.nelement() * param.element_size()
nintermediates = 1 # we expect a budget of 1 intermediate most of the time
if kwargs.get('capturable') or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
# Check the param group directly to handle if the compiler set capturable
if optimizer.param_groups[0].get("capturable", False) or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
# with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
# with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
# ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
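The condition above reads `capturable` from the live param group rather than from the test's construction kwargs, since the compiled optimizer may turn the flag on after construction. A minimal sketch of the distinction, using a plain CPU Adam instance purely for illustration (the `capturable` param-group key itself is real):

```python
import torch

param = torch.zeros(2, requires_grad=True)
opt = torch.optim.Adam([param], capturable=False)

# Something downstream (done by hand here; in the test it would be the
# compiled optimizer) can flip the flag on the group after construction.
opt.param_groups[0]["capturable"] = True

# The construction kwargs still say capturable=False, but the live group
# is the source of truth the test now consults.
print(opt.param_groups[0].get("capturable", False))  # True
```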
@@ -834,12 +835,22 @@ class TestOptimRenewed(TestCase):
if optim_cls.__name__ == "NAdam":
# with capturable in NAdam, we have 3 extra intermediates for the
# bias_correction, mus, and mu_nexts
nintermediates = 5
if TEST_WITH_TORCHDYNAMO:
# With dynamo, the eager/FX backend appears to hold memory longer than
# vanilla eager: https://github.com/pytorch/pytorch/issues/125511
nintermediates = 8
else:
nintermediates = 5
if optim_cls.__name__ == "RAdam":
# RAdam has four intermediates with capturable
# num, unrect_step_size, buffer, grouped_grads
nintermediates = 4
if TEST_WITH_TORCHDYNAMO:
# With dynamo, the eager/FX backend appears to hold memory longer than
# vanilla eager: https://github.com/pytorch/pytorch/issues/125511
nintermediates = 6
else:
nintermediates = 4
elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop"]:
# NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
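The intermediate budgets above are counted in units of `intermediate_size`, i.e. one temporary as large as all parameters combined. A rough worked example with made-up sizes (not the ones the test actually runs with):

```python
import torch

nparams = 10                                   # hypothetical parameter count
param = torch.zeros(256, dtype=torch.float32)  # hypothetical parameter shape

# One "intermediate" = a temporary the size of all parameters together.
intermediate_size = nparams * param.nelement() * param.element_size()  # 10 * 256 * 4 = 10240 bytes

# e.g. NAdam with capturable under dynamo budgets 8 such temporaries.
nintermediates = 8
print(nintermediates * intermediate_size)  # 81920 bytes of headroom above st_max_mem
```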
@@ -847,6 +858,11 @@ class TestOptimRenewed(TestCase):
# RMSprop uses avg and grads
nintermediates = 2
# The single-tensor (ST) path under dynamo uses less memory than eager for Adam/Adagrad/NAdam/RAdam,
# which makes the foreach memory check fail
if TEST_WITH_TORCHDYNAMO:
st_max_mem += 6000
expected_max_mem = st_max_mem + intermediate_size * nintermediates
# hipcc currently can't generate efficient code for the small buffer optimization
# code path (see Note [small buffer optimization] for details), thus we always

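For context, a hedged sketch of the kind of comparison test_peak_memory_foreach makes: measure the peak CUDA allocation of a single-tensor step and of a foreach step, then require the foreach peak to stay within the intermediate budget. This is a simplified stand-alone illustration, not the actual test; `peak_mem_of_step` and the tensor shapes are made up.

```python
import torch

def peak_mem_of_step(foreach):
    # Hypothetical "model": 10 parameters of 256x256 fp32 each.
    params = [torch.zeros(256, 256, device="cuda", requires_grad=True) for _ in range(10)]
    opt = torch.optim.Adam(params, foreach=foreach)
    for p in params:
        p.grad = torch.ones_like(p)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    opt.step()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

if torch.cuda.is_available():
    st_max_mem = peak_mem_of_step(foreach=False)  # single-tensor baseline
    mt_max_mem = peak_mem_of_step(foreach=True)   # foreach (multi-tensor) run

    # One "intermediate" is a temporary the size of all parameters combined.
    intermediate_size = 10 * 256 * 256 * 4
    nintermediates = 1  # default budget; widened per optimizer/config as in the diff above
    expected_max_mem = st_max_mem + intermediate_size * nintermediates
    print(mt_max_mem, "expected to be <=", expected_max_mem)
```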

@@ -1103,13 +1103,6 @@ optim_db: List[OptimizerInfo] = [
"test_tensor_lr",
active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo("See #116028"),
"TestOptimRenewed",
@@ -1175,13 +1168,6 @@ optim_db: List[OptimizerInfo] = [
"test_tensor_lr",
active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo("See #116028"),
"TestOptimRenewed",
@@ -1277,13 +1263,6 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
@@ -1328,13 +1307,6 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
@@ -1407,13 +1379,6 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
@@ -1442,13 +1407,6 @@ optim_db: List[OptimizerInfo] = [
"test_tensor_lr",
active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1547,13 +1505,6 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
@@ -1604,13 +1555,6 @@ optim_db: List[OptimizerInfo] = [
"test_tensor_lr",
active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1662,13 +1606,6 @@ optim_db: List[OptimizerInfo] = [
"test_tensor_lr",
active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo("See #116028"),
"TestOptimRenewed",
@@ -1718,13 +1655,6 @@ optim_db: List[OptimizerInfo] = [
"test_tensor_lr",
active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo("See #116028"),
"TestOptimRenewed",
@@ -1803,13 +1733,6 @@ optim_db: List[OptimizerInfo] = [
"test_tensor_lr",
active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"