typing proxy_tensor.py (#129182)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129182
Approved by: https://github.com/Chillee
This commit is contained in:
Aaron Orenstein
2024-07-12 08:19:14 -07:00
committed by PyTorch MergeBot
parent ea78b0c177
commit 634b62f111
18 changed files with 758 additions and 338 deletions

View File

@ -645,7 +645,7 @@ class DistributedDataParallel(Module, Joinable):
):
super().__init__()
Joinable.__init__(self)
self.logger = None
self.logger: Optional[dist.Logger] = None
if bool(delay_all_reduce_named_params is not None) != bool(
param_to_hook_all_reduce is not None
):
@ -1207,9 +1207,11 @@ class DistributedDataParallel(Module, Joinable):
param_to_name_mapping,
# User can set dist._DEFAULT_FIRST_BUCKET_BYTES to tune DDP first
# bucket.
dist._DEFAULT_FIRST_BUCKET_BYTES
if self.bucket_bytes_cap_default
else self.bucket_bytes_cap,
(
dist._DEFAULT_FIRST_BUCKET_BYTES
if self.bucket_bytes_cap_default
else self.bucket_bytes_cap
),
)
self.logger = dist.Logger(self.reducer)