Feat: working examples

Author: S1ro1
Date: 2025-07-28 11:11:46 +00:00
Parent: d21ff9f245
Commit: aafde25cfc
3 changed files with 19 additions and 5 deletions


@@ -59,7 +59,10 @@ def forward(model, batch, optimizer, accelerator):
     loss = outputs.loss
     accelerator.backward(loss)
-    optimizer.step()
+    from torch.distributed.tensor.experimental import implicit_replication
+
+    with implicit_replication():
+        optimizer.step()
     optimizer.zero_grad()
     dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
     return loss
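
As an aside on why the step gets wrapped: implicit_replication() from torch.distributed.tensor.experimental lets operations that mix DTensor parameters with plain torch.Tensors (such as the scalar state an optimizer keeps) treat the plain tensors as replicated DTensors, which is what allows optimizer.step() to run when the model's parameters are DTensors under tensor parallelism. Below is a minimal, self-contained sketch of the pattern; the toy parameter, mesh, and optimizer are illustrative and not taken from this example script.

# Minimal sketch (not from this repo): torchrun --nproc-per-node=2 implicit_replication_demo.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor
from torch.distributed.tensor.experimental import implicit_replication

dist.init_process_group("gloo")  # "nccl" for GPUs
mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("tp",))

# A toy DTensor parameter (replicated here just to keep the sketch small).
weight = torch.nn.Parameter(distribute_tensor(torch.randn(4, 4), mesh, [Replicate()]))
optimizer = torch.optim.AdamW([weight], lr=1e-3)

x = distribute_tensor(torch.randn(8, 4), mesh, [Replicate()])
loss = (x @ weight).pow(2).mean()
loss.backward()

# The optimizer step mixes the DTensor parameter with plain tensors (step
# counters, scalars); inside implicit_replication() those plain tensors are
# treated as replicated DTensors, so the mixed ops are accepted.
with implicit_replication():
    optimizer.step()
optimizer.zero_grad()

dist.destroy_process_group()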
@@ -69,10 +69,8 @@ def main():
     set_seed(42)
     args = parse_args()
-    if args.dp_replicate_size == 1:
-        warnings.warn(
-            "Accelerator.save_state() is not yet supported with pure tensor parallel training."
-        )
+    if args.dp_shard_size == 1:
+        warnings.warn("Accelerator.save_state() is not yet supported with pure tensor parallel training.")
     parallelism_config = ParallelismConfig(
         dp_replicate_size=args.dp_replicate_size,
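
The updated check keys off the sharded data-parallel dimension: with dp_shard_size == 1 there is no FSDP sharding, so a non-trivial tp_size means pure tensor parallel training, which is the case Accelerator.save_state() does not yet cover. A small sketch of that setup, assuming the ParallelismConfig fields shown in this diff and that it is importable from the top-level accelerate package (the import path and the chosen sizes are assumptions, not taken from this script):

# Sketch only: launch with e.g. `accelerate launch --num_processes 8 train.py`.
import warnings

from accelerate import Accelerator
from accelerate import ParallelismConfig  # import path may differ by accelerate version

# All 8 processes go to tensor parallelism; no data-parallel dimensions at all.
pc = ParallelismConfig(dp_replicate_size=1, dp_shard_size=1, tp_size=8)

if pc.dp_shard_size == 1:
    # Same condition as in the example above: no sharded data-parallel dimension
    # means pure tensor parallel training, where save_state() is not yet supported.
    warnings.warn("Accelerator.save_state() is not yet supported with pure tensor parallel training.")

accelerator = Accelerator(parallelism_config=pc)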


@@ -749,6 +749,18 @@ class Accelerator:
     def torch_device_mesh(self):
         return self.state.device_mesh
 
+    @property
+    def should_save_model(self):
+        if (pc := self.parallelism_config) is None:
+            # shouldn't even happen
+            return self.state.is_local_main_process
+
+        non_model_shard_dims = {
+            pc.dp_enabled: "dp_replicate",
+            pc.cp_enabled: "cp",
+        }
+        return all(self.torch_device_mesh[dim].get_local_rank() == 0 for key, dim in non_model_shard_dims.items() if key)
+
     def _setup_parallelism_config(
         self, parallelism_config: ParallelismConfig | None, torch_tp_plugin: TorchTensorParallelPlugin | None
     ):
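
The new property answers "does this rank need to write the model?": it requires local rank 0 along every mesh dimension that merely replicates the model (dp_replicate, and cp when enabled), while ranks along sharding dimensions (dp_shard, tp) each own distinct pieces and therefore all participate in saving. Here is a standalone sketch of that rank check on a hypothetical 3-D mesh of 8 processes; it illustrates the idea rather than the Accelerator code path:

# Sketch: torchrun --nproc-per-node=8 should_save_demo.py
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")
mesh = init_device_mesh(
    "cpu", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp")
)

# The model is only replicated along dp_replicate (and cp, if present), so just
# the first rank along each such dimension needs to save; every coordinate along
# dp_shard/tp holds a distinct shard and always participates.
replicating_dims = ("dp_replicate",)
should_save_model = all(mesh[dim].get_local_rank() == 0 for dim in replicating_dims)

print(f"global rank {dist.get_rank()}: should_save_model={should_save_model}")
dist.destroy_process_group()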


@@ -2985,6 +2985,7 @@ class ParallelismConfig:
     def non_data_parallel_size(self):
         return self.tp_size * self.cp_size
 
+    @property
     def data_parallel_size(self):
         return self.dp_replicate_size * self.dp_shard_size
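
With the added decorator, data_parallel_size is read as an attribute, mirroring non_data_parallel_size; together the two factor the total number of processes into a data-parallel part (dp_replicate * dp_shard) and a model-parallel part (tp * cp). A quick sketch of that bookkeeping, again assuming ParallelismConfig is importable from the top-level package and can be constructed outside a launched job:

from accelerate import ParallelismConfig  # import path may differ by accelerate version

pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2, cp_size=1)

# Attribute access now, not a call, thanks to @property.
assert pc.data_parallel_size == 2 * 2        # dp_replicate_size * dp_shard_size
assert pc.non_data_parallel_size == 2 * 1    # tp_size * cp_size

# The two factors multiply back to the total world size the config expects.
print(pc.data_parallel_size * pc.non_data_parallel_size)  # 8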