mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Bug fixes to DDP _update_process_group API. (#114194)
https://github.com/pytorch/pytorch/pull/113580 introduced the `DDP._update_process_group` API. However, the implementation did not correctly reset all of the necessary state in the reducer. In particular if an error occurred during backward, DDP would end up in an incorrect state. As a result, in this PR I've enhanced the unit test to test for this case and also appropriately fixed resetting Reducer state. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114194 Approved by: https://github.com/rohan-varma
This commit is contained in:
committed by
PyTorch MergeBot
parent
7c98bac4a0
commit
f505d76462
@ -66,7 +66,7 @@ class Reducer:
|
||||
def _remove_autograd_hooks(self) -> None: ...
|
||||
def _check_reducer_finalized(self) -> None: ...
|
||||
def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ...
|
||||
def _force_bucket_rebuild(self) -> None: ...
|
||||
def _reset_state(self) -> None: ...
|
||||
def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
|
||||
|
||||
class DDPLoggingData:
|
||||
|
@ -598,10 +598,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
|
||||
[](::c10d::Reducer& reducer) { return reducer.check_finalized(); },
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"_force_bucket_rebuild",
|
||||
[](::c10d::Reducer& reducer) {
|
||||
return reducer.force_bucket_rebuild();
|
||||
},
|
||||
"_reset_state",
|
||||
[](::c10d::Reducer& reducer) { return reducer.reset_state(); },
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"_update_process_group",
|
||||
|
@ -2314,11 +2314,20 @@ void Reducer::update_process_group(
|
||||
process_group_ = std::move(new_process_group);
|
||||
}
|
||||
|
||||
void Reducer::force_bucket_rebuild() {
|
||||
void Reducer::reset_state() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
// Force rebuild of buckets.
|
||||
has_rebuilt_bucket_ = false;
|
||||
rebuilt_params_.clear();
|
||||
rebuilt_param_indices_.clear();
|
||||
|
||||
// Ensure forward can run despite previous backward not succeeding.
|
||||
expect_autograd_hooks_ = false;
|
||||
require_finalize_ = false;
|
||||
|
||||
// Unset allreduce division factor, as it may change in next backwards pass
|
||||
// when running with DDP join mode.
|
||||
div_factor_ = kUnsetDivFactor;
|
||||
}
|
||||
|
||||
} // namespace c10d
|
||||
|
@ -194,8 +194,8 @@ class TORCH_API Reducer {
|
||||
void update_process_group(
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> new_process_group);
|
||||
|
||||
// Forces a rebuild of buckets on next iteration.
|
||||
void force_bucket_rebuild();
|
||||
// Resets reducer state.
|
||||
void reset_state();
|
||||
|
||||
protected:
|
||||
// Forward declaration.
|
||||
|
@ -2254,7 +2254,7 @@ class DistributedDataParallel(Module, Joinable):
|
||||
# re-evaluates previous assumptions of buckets given the world size might have
|
||||
# changed.
|
||||
self._has_rebuilt_buckets = False
|
||||
self.reducer._force_bucket_rebuild()
|
||||
self.reducer._reset_state()
|
||||
|
||||
if not _rank_not_in_group(new_process_group):
|
||||
self.process_group = new_process_group
|
||||
|
@ -9566,79 +9566,124 @@ class DistributedTest:
|
||||
running = False
|
||||
t.join()
|
||||
|
||||
def _run_ddp_update_process_group(self, new_pg):
|
||||
def get_num_torch_recompiles():
|
||||
guard_failures = torch._dynamo.utils.guard_failures
|
||||
num_recompiles = [len(guard_failures[code]) for code in guard_failures]
|
||||
return 0 if len(num_recompiles) == 0 else max(num_recompiles)
|
||||
|
||||
class SimulateError(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise RuntimeError()
|
||||
|
||||
class MyModel(torch.nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
# 4MB for multiple buckets.
|
||||
self.fc1 = torch.nn.Linear(1024, 1024).cuda(device)
|
||||
self.fc2 = torch.nn.Linear(1024, 1024).cuda(device)
|
||||
self.fc3 = torch.nn.Linear(1024, 1024).cuda(device)
|
||||
|
||||
def forward(self, inp, error):
|
||||
if error:
|
||||
return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
|
||||
else:
|
||||
return self.fc3(self.fc2(self.fc1(inp)))
|
||||
|
||||
|
||||
input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
|
||||
ddp = torch.nn.parallel.DistributedDataParallel(
|
||||
MyModel(self.rank),
|
||||
device_ids=[self.rank],
|
||||
find_unused_parameters=True,
|
||||
bucket_cap_mb=1,
|
||||
)
|
||||
model = torch.compile(ddp)
|
||||
|
||||
def run_iteration():
|
||||
# Run regular iteration.
|
||||
out = model(input, error=False)
|
||||
out.sum().backward()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Run with error.
|
||||
with self.assertRaises(RuntimeError):
|
||||
out = model(input, error=True)
|
||||
out.sum().backward()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
run_iteration()
|
||||
assert 0 == get_num_torch_recompiles()
|
||||
|
||||
if new_pg:
|
||||
# Now reduce world_size and run iteration.
|
||||
group_size_2 = dist.new_group(ranks=[0, 1])
|
||||
ddp._update_process_group(group_size_2)
|
||||
if self.rank in [0, 1]:
|
||||
run_iteration()
|
||||
|
||||
# Increase the world size and run iteration.
|
||||
group_size_3 = dist.new_group(ranks=[1, 2, 3])
|
||||
ddp._update_process_group(group_size_3)
|
||||
if self.rank in [1, 2, 3]:
|
||||
run_iteration()
|
||||
|
||||
# Back to default size.
|
||||
ddp._update_process_group(_get_default_group())
|
||||
run_iteration()
|
||||
else:
|
||||
# Create default pg of smaller size.
|
||||
dist.destroy_process_group()
|
||||
|
||||
if self.rank in [1, 2, 3]:
|
||||
dist.init_process_group(
|
||||
init_method=self.init_method,
|
||||
backend=BACKEND,
|
||||
world_size=3,
|
||||
rank=self.rank - 1,
|
||||
timeout=timedelta(seconds=default_pg_timeout),
|
||||
)
|
||||
ddp._update_process_group(_get_default_group())
|
||||
run_iteration()
|
||||
dist.destroy_process_group()
|
||||
|
||||
# Need a barrier here to ensure ranks 1, 2 and 3 are done.
|
||||
self._barrier(wait_for=4)
|
||||
|
||||
# Need to init pg again for "_barrier" to succeed.
|
||||
dist.init_process_group(
|
||||
init_method=self.init_method,
|
||||
backend=BACKEND,
|
||||
world_size=4,
|
||||
rank=self.rank,
|
||||
timeout=timedelta(seconds=default_pg_timeout),
|
||||
)
|
||||
|
||||
# Validate no more recompiles.
|
||||
assert 0 == get_num_torch_recompiles()
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@require_world_size(4)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND not in DistTestCases.backend_feature["ddp"],
|
||||
f"The {BACKEND} backend does not support DistributedDataParallel",
|
||||
)
|
||||
def test_ddp_update_process_group(self):
|
||||
def get_num_torch_recompiles():
|
||||
guard_failures = torch._dynamo.utils.guard_failures
|
||||
num_recompiles = [len(guard_failures[code]) for code in guard_failures]
|
||||
return 0 if len(num_recompiles) == 0 else max(num_recompiles)
|
||||
|
||||
input = torch.rand(10, 10).cuda(self.rank)
|
||||
ddp = torch.nn.parallel.DistributedDataParallel(
|
||||
torch.nn.Linear(10, 10).cuda(self.rank),
|
||||
device_ids=[self.rank],
|
||||
)
|
||||
model = torch.compile(ddp)
|
||||
|
||||
def run_iteration():
|
||||
out = model(input)
|
||||
out.sum().backward()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Run regular iteration.
|
||||
run_iteration()
|
||||
num_compiles = get_num_torch_recompiles()
|
||||
assert 0 == num_compiles
|
||||
|
||||
# Now reduce world_size and run iteration.
|
||||
group_size_2 = dist.new_group(ranks=[0, 1])
|
||||
ddp._update_process_group(group_size_2)
|
||||
if self.rank in [0, 1]:
|
||||
run_iteration()
|
||||
|
||||
# Increase the world size and run iteration.
|
||||
group_size_3 = dist.new_group(ranks=[1, 2, 3])
|
||||
ddp._update_process_group(group_size_3)
|
||||
if self.rank in [1, 2, 3]:
|
||||
run_iteration()
|
||||
|
||||
# Back to default size.
|
||||
ddp._update_process_group(_get_default_group())
|
||||
run_iteration()
|
||||
|
||||
# Now create default pg of smaller size.
|
||||
dist.destroy_process_group()
|
||||
|
||||
if self.rank in [1, 2, 3]:
|
||||
dist.init_process_group(
|
||||
init_method=self.init_method,
|
||||
backend=BACKEND,
|
||||
world_size=3,
|
||||
rank=self.rank - 1,
|
||||
timeout=timedelta(seconds=default_pg_timeout),
|
||||
)
|
||||
ddp._update_process_group(_get_default_group())
|
||||
run_iteration()
|
||||
dist.destroy_process_group()
|
||||
|
||||
# Need to init pg again for "_barrier" to succeed.
|
||||
dist.init_process_group(
|
||||
init_method=self.init_method,
|
||||
backend=BACKEND,
|
||||
world_size=4,
|
||||
rank=self.rank,
|
||||
timeout=timedelta(seconds=default_pg_timeout),
|
||||
)
|
||||
|
||||
# Validate no more recompiles.
|
||||
num_compiles = get_num_torch_recompiles()
|
||||
assert 0 == num_compiles
|
||||
def test_ddp_update_process_group_new_group(self):
|
||||
self._run_ddp_update_process_group(new_pg=True)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@require_world_size(4)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND not in DistTestCases.backend_feature["ddp"],
|
||||
f"The {BACKEND} backend does not support DistributedDataParallel",
|
||||
)
|
||||
def test_ddp_update_process_group_default_group(self):
|
||||
self._run_ddp_update_process_group(new_pg=False)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
|
Reference in New Issue
Block a user