[DDP] Bucket handling: make first bucket size equal to bucket_cap_mb if it was set (#121640)

The first DDP bucket is always created with the size of `dist._DEFAULT_FIRST_BUCKET_BYTES` (1 MiB) by default, regardless of `bucket_cap_mb`. The proposal is to use `bucket_cap_mb` as the single bucket size, including for the first bucket, if it was supplied by the user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121640
Approved by: https://github.com/wanchaol
This commit is contained in:
Aidyn-A
2024-06-05 23:44:51 +00:00
committed by PyTorch MergeBot
parent ffaea656b5
commit e98662bed9

View File

@ -548,7 +548,8 @@ class DistributedDataParallel(Module, Joinable):
multiple buckets so that gradient reduction of each
bucket can potentially overlap with backward computation.
:attr:`bucket_cap_mb` controls the bucket size in
MegaBytes (MB). (default: 25)
MebiBytes (MiB). If ``None``, a default size of 25 MiB
will be used. (default: ``None``)
find_unused_parameters (bool): Traverse the autograd graph from all
tensors contained in the return value of the
wrapped module's ``forward`` function. Parameters
@ -631,7 +632,7 @@ class DistributedDataParallel(Module, Joinable):
dim=0,
broadcast_buffers=True,
process_group=None,
bucket_cap_mb=25,
bucket_cap_mb=None,
find_unused_parameters=False,
check_reduction=False,
gradient_as_bucket_view=False,
@ -788,7 +789,14 @@ class DistributedDataParallel(Module, Joinable):
self.broadcast_bucket_size = int(250 * 1024 * 1024)
# reduction bucket size
self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
if bucket_cap_mb is None:
# default case (bucket cap is 25 MiB)
self.bucket_bytes_cap_default = True
self.bucket_bytes_cap = int(25 * 1024 * 1024)
else:
self.bucket_bytes_cap_default = False
self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
# Whether to perform input tensor CPU to GPU copies on a side-stream
self.use_side_stream_for_tensor_copies = (
os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
@ -1156,10 +1164,13 @@ class DistributedDataParallel(Module, Joinable):
if static_graph is True or self.find_unused_parameters is False:
bucket_size_limits = [sys.maxsize]
else:
bucket_size_limits = [
dist._DEFAULT_FIRST_BUCKET_BYTES,
self.bucket_bytes_cap,
]
if self.bucket_bytes_cap_default:
bucket_size_limits = [
dist._DEFAULT_FIRST_BUCKET_BYTES,
self.bucket_bytes_cap,
]
else:
bucket_size_limits = [self.bucket_bytes_cap]
(
bucket_indices,
per_bucket_size_limits,
@ -1195,7 +1206,9 @@ class DistributedDataParallel(Module, Joinable):
param_to_name_mapping,
# User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
# bucket.
dist._DEFAULT_FIRST_BUCKET_BYTES,
dist._DEFAULT_FIRST_BUCKET_BYTES
if self.bucket_bytes_cap_default
else self.bucket_bytes_cap,
)
self.logger = dist.Logger(self.reducer)