Make grad accum steps mutable on the Accelerator object (#1233)

* Make grad accum steps mutable

* Reset state
Author: Zachary Mueller
Date: 2023-03-22 17:44:31 -04:00
Committed by: GitHub
Parent: 6e4e870203
Commit: b1b3312749
2 changed files with 21 additions and 1 deletion

src/accelerate/accelerator.py

@@ -819,10 +819,18 @@ class Accelerator:
     def sync_gradients(self):
         return self.gradient_state.sync_gradients
 
     @sync_gradients.setter
     def sync_gradients(self, sync_gradients):
         self.gradient_state.sync_gradients = sync_gradients
 
+    @property
+    def gradient_accumulation_steps(self):
+        return self.gradient_state.num_steps
+
+    @gradient_accumulation_steps.setter
+    def gradient_accumulation_steps(self, gradient_accumulation_steps):
+        self.gradient_state.plugin_kwargs.update({"num_steps": gradient_accumulation_steps})
+
     @contextmanager
     def accumulate(self, model):
         """

tests/test_accelerator.py

@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader, TensorDataset
 
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.accelerator import Accelerator
-from accelerate.state import PartialState
+from accelerate.state import GradientState, PartialState
 from accelerate.test_utils import require_multi_gpu, slow
 from accelerate.test_utils.testing import AccelerateTestCase, require_cuda
 from accelerate.utils import patch_environment
@@ -43,6 +43,18 @@ class AcceleratorTester(AccelerateTestCase):
         with self.assertRaises(ValueError):
             _ = Accelerator(cpu=True)
 
+    def test_mutable_states(self):
+        accelerator = Accelerator()
+        state = GradientState()
+        assert state.num_steps == 1
+        accelerator.gradient_accumulation_steps = 4
+        assert state.num_steps == 4
+
+        assert state.sync_gradients is True
+        accelerator.sync_gradients = False
+        assert state.sync_gradients is False
+        GradientState._reset_state()
+
     def test_prepared_objects_are_referenced(self):
         accelerator = Accelerator()
         model, optimizer, scheduler, train_dl, valid_dl = create_components()
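
The test hinges on GradientState being shared state: every GradientState()
instance reads and writes the same underlying storage, which is why a state
constructed inside the test observes a mutation made through the Accelerator.
A minimal sketch of that behavior, using only names that appear in this diff:

    from accelerate.state import GradientState

    a = GradientState()
    b = GradientState()
    a.plugin_kwargs.update({"num_steps": 8})  # same write the new setter performs
    assert b.num_steps == 8                   # a separate instance sees it
    GradientState._reset_state()              # restore defaults, as the test does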