Mirror of https://github.com/huggingface/accelerate.git, synced 2025-10-20 10:03:46 +08:00
Feat: working examples
@@ -59,7 +59,10 @@ def forward(model, batch, optimizer, accelerator):
     loss = outputs.loss
 
     accelerator.backward(loss)
-    optimizer.step()
+    from torch.distributed.tensor.experimental import implicit_replication
+
+    with implicit_replication():
+        optimizer.step()
     optimizer.zero_grad()
     dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
     return loss
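The change above wraps the optimizer step in PyTorch's `implicit_replication` context manager. When the model parameters are DTensors (as they are after tensor-parallel sharding), plain tensors touched inside `optimizer.step()` would otherwise hit mixed DTensor/regular-tensor errors; inside the context manager they are treated as replicated. A minimal standalone sketch of the pattern, assuming `model` has already been sharded into DTensors and `inputs` is a prepared batch (both assumptions, not part of this diff):

    # Sketch only: isolates the implicit_replication pattern from the example script.
    import torch
    from torch.distributed.tensor.experimental import implicit_replication

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # hypothetical optimizer/lr

    loss = model(inputs).sum()  # any scalar objective stands in for outputs.loss
    loss.backward()
    with implicit_replication():  # plain tensors seen during step() are treated as replicated
        optimizer.step()
    optimizer.zero_grad()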
@@ -69,10 +72,8 @@ def main():
     set_seed(42)
     args = parse_args()
 
-    if args.dp_replicate_size == 1:
-        warnings.warn(
-            "Accelerator.save_state() is not yet supported with pure tensor parallel training."
-        )
+    if args.dp_shard_size == 1:
+        warnings.warn("Accelerator.save_state() is not yet supported with pure tensor parallel training.")
 
     parallelism_config = ParallelismConfig(
         dp_replicate_size=args.dp_replicate_size,
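The `ParallelismConfig(...)` call is cut off by the end of the hunk. For orientation, a hedged sketch of how such a config is typically wired into the `Accelerator`; the argument names beyond `dp_replicate_size` and the `parallelism_config` keyword are assumptions based on the names used elsewhere in this diff, not something the hunk shows:

    # Sketch only: assumed continuation of the constructor shown above.
    parallelism_config = ParallelismConfig(
        dp_replicate_size=args.dp_replicate_size,  # shown in the hunk
        dp_shard_size=args.dp_shard_size,          # assumed, mirrors the warning above
        tp_size=args.tp_size,                      # assumed
    )
    accelerator = Accelerator(parallelism_config=parallelism_config)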
@@ -749,6 +749,18 @@ class Accelerator:
     def torch_device_mesh(self):
         return self.state.device_mesh
 
+    @property
+    def should_save_model(self):
+        if (pc := self.parallelism_config) is None:
+            # shouldn't even happen
+            return self.state.is_local_main_process
+        non_model_shard_dims = {
+            pc.dp_enabled: "dp_replicate",
+            pc.cp_enabled: "cp",
+        }
+
+        return all(self.torch_device_mesh[dim].get_local_rank() == 0 for key, dim in non_model_shard_dims.items() if key)
+
     def _setup_parallelism_config(
         self, parallelism_config: ParallelismConfig | None, torch_tp_plugin: TorchTensorParallelPlugin | None
     ):
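The new `should_save_model` property is true only on ranks whose coordinate is 0 along every mesh dimension that replicates rather than shards the model (`dp_replicate`, and `cp` when enabled), so those ranks together hold exactly one copy of the weights. A hedged sketch of what it evaluates to, assuming the common layout where `dp_replicate` is the outermost mesh dimension:

    # Sketch only: inspect which ranks are expected to write model files.
    # With dp_replicate=2 and tp=2 on 4 processes (assumed layout), the first replica's
    # two TP shards (ranks 0 and 1) report True; ranks 2 and 3 hold duplicates and report False.
    flag = accelerator.should_save_model
    print(f"rank {accelerator.process_index}: should_save_model={flag}")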
@@ -2985,6 +2985,7 @@ class ParallelismConfig:
     def non_data_parallel_size(self):
         return self.tp_size * self.cp_size
 
+    @property
     def data_parallel_size(self):
         return self.dp_replicate_size * self.dp_shard_size
 
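`data_parallel_size` and `non_data_parallel_size` factor the total number of processes, which makes for a simple sanity check when picking mesh sizes. A hedged sketch; the `accelerator` handle and the comparison against `num_processes` are the obvious usage, not something this hunk shows:

    # Sketch only: the two properties should multiply out to the world size.
    pc = accelerator.parallelism_config
    assert pc.data_parallel_size * pc.non_data_parallel_size == accelerator.num_processes
    # e.g. on 16 GPUs: dp_replicate=2, dp_shard=4 (data-parallel 8) with tp=2, cp=1 (non-data-parallel 2)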