Bug fixes to DDP _update_process_group API. (#114194)

https://github.com/pytorch/pytorch/pull/113580 introduced the `DDP._update_process_group` API. However, the implementation did not correctly reset all of the necessary Reducer state. In particular, if an error occurred during backward, DDP would be left in an incorrect state.

To address this, in this PR I've extended the unit test to cover this case and fixed the Reducer state reset accordingly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114194
Approved by: https://github.com/rohan-varma
Author: Pritam Damania
Date: 2023-11-27 23:52:35 +00:00
Committed by: PyTorch MergeBot
Parent: 7c98bac4a0
Commit: f505d76462

6 changed files with 127 additions and 75 deletions
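
For context, here is a minimal sketch of the failure scenario this PR addresses, mirroring the updated unit test below. It is illustrative only (the FailInBackward helper and tensor shapes are not from the PR) and assumes torch.distributed is already initialized, e.g. via torchrun, with one GPU per rank.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class FailInBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Simulates a user error surfacing partway through backward.
        raise RuntimeError("simulated backward failure")

rank = dist.get_rank()
model = DDP(torch.nn.Linear(8, 8).cuda(rank), device_ids=[rank])
inp = torch.rand(4, 8, requires_grad=True).cuda(rank)

try:
    model(FailInBackward.apply(inp)).sum().backward()
except RuntimeError:
    pass  # backward aborted partway through; some reducer state is now stale

# Before this fix, _update_process_group only forced a bucket rebuild, so the
# reducer could still be expecting autograd hooks and a finalize that never
# happened. With the fix it calls Reducer::reset_state(), and the next
# iteration can run normally.
model._update_process_group(dist.new_group(ranks=list(range(dist.get_world_size()))))
model(inp).sum().backward()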


@@ -66,7 +66,7 @@ class Reducer:
     def _remove_autograd_hooks(self) -> None: ...
     def _check_reducer_finalized(self) -> None: ...
     def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ...
-    def _force_bucket_rebuild(self) -> None: ...
+    def _reset_state(self) -> None: ...
     def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...

 class DDPLoggingData:


@@ -598,10 +598,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
           [](::c10d::Reducer& reducer) { return reducer.check_finalized(); },
           py::call_guard<py::gil_scoped_release>())
       .def(
-          "_force_bucket_rebuild",
-          [](::c10d::Reducer& reducer) {
-            return reducer.force_bucket_rebuild();
-          },
+          "_reset_state",
+          [](::c10d::Reducer& reducer) { return reducer.reset_state(); },
          py::call_guard<py::gil_scoped_release>())
       .def(
           "_update_process_group",


@@ -2314,11 +2314,20 @@ void Reducer::update_process_group(
   process_group_ = std::move(new_process_group);
 }

-void Reducer::force_bucket_rebuild() {
+void Reducer::reset_state() {
   std::lock_guard<std::mutex> lock(mutex_);
+  // Force rebuild of buckets.
   has_rebuilt_bucket_ = false;
   rebuilt_params_.clear();
   rebuilt_param_indices_.clear();
+
+  // Ensure forward can run despite previous backward not succeeding.
+  expect_autograd_hooks_ = false;
+  require_finalize_ = false;
+
+  // Unset allreduce division factor, as it may change in next backwards pass
+  // when running with DDP join mode.
+  div_factor_ = kUnsetDivFactor;
 }

 } // namespace c10d


@@ -194,8 +194,8 @@ class TORCH_API Reducer {
   void update_process_group(
       c10::intrusive_ptr<c10d::ProcessGroup> new_process_group);

-  // Forces a rebuild of buckets on next iteration.
-  void force_bucket_rebuild();
+  // Resets reducer state.
+  void reset_state();

  protected:
   // Forward declaration.


@@ -2254,7 +2254,7 @@ class DistributedDataParallel(Module, Joinable):
         # re-evaluates previous assumptions of buckets given the world size might have
         # changed.
         self._has_rebuilt_buckets = False
-        self.reducer._force_bucket_rebuild()
+        self.reducer._reset_state()

         if not _rank_not_in_group(new_process_group):
             self.process_group = new_process_group
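
Since the hunk above is the user-facing entry point, a rough usage sketch of shrinking and then restoring the process group may help. It mirrors the test below and is not from the PR itself; it assumes four ranks with an initialized default group and one GPU per rank.

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.nn.parallel import DistributedDataParallel as DDP

rank = dist.get_rank()
ddp = DDP(torch.nn.Linear(8, 8).cuda(rank), device_ids=[rank])
inp = torch.rand(4, 8).cuda(rank)

ddp(inp).sum().backward()                    # train on the full world (4 ranks)

small_pg = dist.new_group(ranks=[0, 1])      # collective call: every rank participates
ddp._update_process_group(small_pg)          # resets reducer state and swaps the group
if rank in (0, 1):                           # only members of the new group keep training
    ddp(inp).sum().backward()

ddp._update_process_group(_get_default_group())  # back to the default group
ddp(inp).sum().backward()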


@@ -9566,79 +9566,124 @@ class DistributedTest:
             running = False
             t.join()

+        def _run_ddp_update_process_group(self, new_pg):
+            def get_num_torch_recompiles():
+                guard_failures = torch._dynamo.utils.guard_failures
+                num_recompiles = [len(guard_failures[code]) for code in guard_failures]
+                return 0 if len(num_recompiles) == 0 else max(num_recompiles)
+
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError()
+
+            class MyModel(torch.nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    # 4MB for multiple buckets.
+                    self.fc1 = torch.nn.Linear(1024, 1024).cuda(device)
+                    self.fc2 = torch.nn.Linear(1024, 1024).cuda(device)
+                    self.fc3 = torch.nn.Linear(1024, 1024).cuda(device)
+
+                def forward(self, inp, error):
+                    if error:
+                        return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
+                    else:
+                        return self.fc3(self.fc2(self.fc1(inp)))
+
+            input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                MyModel(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=1,
+            )
+            model = torch.compile(ddp)
+
+            def run_iteration():
+                # Run regular iteration.
+                out = model(input, error=False)
+                out.sum().backward()
+                torch.cuda.synchronize()
+
+                # Run with error.
+                with self.assertRaises(RuntimeError):
+                    out = model(input, error=True)
+                    out.sum().backward()
+                torch.cuda.synchronize()
+
+            run_iteration()
+            assert 0 == get_num_torch_recompiles()
+
+            if new_pg:
+                # Now reduce world_size and run iteration.
+                group_size_2 = dist.new_group(ranks=[0, 1])
+                ddp._update_process_group(group_size_2)
+                if self.rank in [0, 1]:
+                    run_iteration()
+
+                # Increase the world size and run iteration.
+                group_size_3 = dist.new_group(ranks=[1, 2, 3])
+                ddp._update_process_group(group_size_3)
+                if self.rank in [1, 2, 3]:
+                    run_iteration()
+
+                # Back to default size.
+                ddp._update_process_group(_get_default_group())
+                run_iteration()
+            else:
+                # Create default pg of smaller size.
+                dist.destroy_process_group()
+
+                if self.rank in [1, 2, 3]:
+                    dist.init_process_group(
+                        init_method=self.init_method,
+                        backend=BACKEND,
+                        world_size=3,
+                        rank=self.rank - 1,
+                        timeout=timedelta(seconds=default_pg_timeout),
+                    )
+                    ddp._update_process_group(_get_default_group())
+                    run_iteration()
+                    dist.destroy_process_group()
+
+                # Need a barrier here to ensure ranks 1, 2 and 3 are done.
+                self._barrier(wait_for=4)
+
+                # Need to init pg again for "_barrier" to succeed.
+                dist.init_process_group(
+                    init_method=self.init_method,
+                    backend=BACKEND,
+                    world_size=4,
+                    rank=self.rank,
+                    timeout=timedelta(seconds=default_pg_timeout),
+                )
+
+            # Validate no more recompiles.
+            assert 0 == get_num_torch_recompiles()
+
         @skip_if_lt_x_gpu(4)
         @require_world_size(4)
         @skip_but_pass_in_sandcastle_if(
             BACKEND not in DistTestCases.backend_feature["ddp"],
             f"The {BACKEND} backend does not support DistributedDataParallel",
         )
-        def test_ddp_update_process_group(self):
-            def get_num_torch_recompiles():
-                guard_failures = torch._dynamo.utils.guard_failures
-                num_recompiles = [len(guard_failures[code]) for code in guard_failures]
-                return 0 if len(num_recompiles) == 0 else max(num_recompiles)
-
-            input = torch.rand(10, 10).cuda(self.rank)
-            ddp = torch.nn.parallel.DistributedDataParallel(
-                torch.nn.Linear(10, 10).cuda(self.rank),
-                device_ids=[self.rank],
-            )
-            model = torch.compile(ddp)
-
-            def run_iteration():
-                out = model(input)
-                out.sum().backward()
-                torch.cuda.synchronize()
-
-            # Run regular iteration.
-            run_iteration()
-            num_compiles = get_num_torch_recompiles()
-            assert 0 == num_compiles
-
-            # Now reduce world_size and run iteration.
-            group_size_2 = dist.new_group(ranks=[0, 1])
-            ddp._update_process_group(group_size_2)
-            if self.rank in [0, 1]:
-                run_iteration()
-
-            # Increase the world size and run iteration.
-            group_size_3 = dist.new_group(ranks=[1, 2, 3])
-            ddp._update_process_group(group_size_3)
-            if self.rank in [1, 2, 3]:
-                run_iteration()
-
-            # Back to default size.
-            ddp._update_process_group(_get_default_group())
-            run_iteration()
-
-            # Now create default pg of smaller size.
-            dist.destroy_process_group()
-            if self.rank in [1, 2, 3]:
-                dist.init_process_group(
-                    init_method=self.init_method,
-                    backend=BACKEND,
-                    world_size=3,
-                    rank=self.rank - 1,
-                    timeout=timedelta(seconds=default_pg_timeout),
-                )
-                ddp._update_process_group(_get_default_group())
-                run_iteration()
-                dist.destroy_process_group()
-
-            # Need to init pg again for "_barrier" to succeed.
-            dist.init_process_group(
-                init_method=self.init_method,
-                backend=BACKEND,
-                world_size=4,
-                rank=self.rank,
-                timeout=timedelta(seconds=default_pg_timeout),
-            )
-
-            # Validate no more recompiles.
-            num_compiles = get_num_torch_recompiles()
-            assert 0 == num_compiles
+        def test_ddp_update_process_group_new_group(self):
+            self._run_ddp_update_process_group(new_pg=True)
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_default_group(self):
+            self._run_ddp_update_process_group(new_pg=False)

         @skip_if_lt_x_gpu(2)
         @skip_but_pass_in_sandcastle_if(