mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 18:13:46 +08:00
Make grad accum steps mutable on the Accelerator object (#1233)
* Make grad accum steps mutable * Reset state
This commit is contained in:
@ -819,10 +819,18 @@ class Accelerator:
|
||||
def sync_gradients(self):
|
||||
return self.gradient_state.sync_gradients
|
||||
|
||||
@sync_gradients.setter
|
||||
def sync_gradients(self, sync_gradients):
|
||||
self.gradient_state.sync_gradients = sync_gradients
|
||||
|
||||
@property
|
||||
def gradient_accumulation_steps(self):
|
||||
return self.gradient_state.num_steps
|
||||
|
||||
@gradient_accumulation_steps.setter
|
||||
def gradient_accumulation_steps(self, gradient_accumulation_steps):
|
||||
self.gradient_state.plugin_kwargs.update({"num_steps": gradient_accumulation_steps})
|
||||
|
||||
@contextmanager
|
||||
def accumulate(self, model):
|
||||
"""
|
||||
|
@ -9,7 +9,7 @@ from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
from accelerate.accelerator import Accelerator
|
||||
from accelerate.state import PartialState
|
||||
from accelerate.state import GradientState, PartialState
|
||||
from accelerate.test_utils import require_multi_gpu, slow
|
||||
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda
|
||||
from accelerate.utils import patch_environment
|
||||
@ -43,6 +43,18 @@ class AcceleratorTester(AccelerateTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = Accelerator(cpu=True)
|
||||
|
||||
def test_mutable_states(self):
|
||||
accelerator = Accelerator()
|
||||
state = GradientState()
|
||||
assert state.num_steps == 1
|
||||
accelerator.gradient_accumulation_steps = 4
|
||||
assert state.num_steps == 4
|
||||
|
||||
assert state.sync_gradients is True
|
||||
accelerator.sync_gradients = False
|
||||
assert state.sync_gradients is False
|
||||
GradientState._reset_state()
|
||||
|
||||
def test_prepared_objects_are_referenced(self):
|
||||
accelerator = Accelerator()
|
||||
model, optimizer, scheduler, train_dl, valid_dl = create_components()
|
||||
|
Reference in New Issue
Block a user