diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index faff37add1e9..cb8157fac625 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -59,7 +59,6 @@ USE_BLACK_FILELIST = re.compile( # torch/[a-c]*/** "torch/[a-c]*/**", # torch/d*/** - "torch/d*/**", # torch/[e-n]*/** "torch/[e-n]*/**", # torch/optim/** diff --git a/torch/distributed/_composable/contract.py b/torch/distributed/_composable/contract.py index 8005415179a3..56ada8791ebf 100644 --- a/torch/distributed/_composable/contract.py +++ b/torch/distributed/_composable/contract.py @@ -36,11 +36,9 @@ _M = TypeVar("_M", nn.Module, list[nn.Module]) class _ContractFn(Protocol, Generic[_P, _T, _TState]): - def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: - ... + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ... - def state(self, module: nn.Module) -> _TState: - ... + def state(self, module: nn.Module) -> _TState: ... def contract( @@ -92,7 +90,7 @@ def contract( # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package @wraps(state_cls) # type: ignore[arg-type] def inner( - func: Callable[Concatenate[_M, _P], _M] + func: Callable[Concatenate[_M, _P], _M], ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]: @wraps(func) def wrapper( @@ -232,9 +230,7 @@ def contract( return module.__dict__.setdefault( # type: ignore[call-overload] STATE_KEY, {}, # TODO(@yhcharles): this is a temporary fix, need a better way - ).get( - func - ) # type: ignore[call-overload] + ).get(func) # type: ignore[call-overload] wrapper.state = get_state # type: ignore[attr-defined] diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 41f55756a1de..921d875455f7 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -274,9 +274,9 @@ def reduce_scatter_tensor( group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) - assert ( - self.size(scatter_dim) % group_size == 0 - ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + assert self.size(scatter_dim) % group_size == 0, ( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + ) if scatter_dim != 0: tensor_list = torch.chunk(self, group_size, dim=scatter_dim) self = torch.cat(tensor_list) @@ -313,9 +313,9 @@ def reduce_scatter_tensor_autograd( group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) - assert ( - self.size(scatter_dim) % group_size == 0 - ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + assert self.size(scatter_dim) % group_size == 0, ( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + ) if scatter_dim != 0: tensor_list = torch.chunk(self, group_size, dim=scatter_dim) self = torch.cat(tensor_list) @@ -414,9 +414,9 @@ def reduce_scatter_tensor_coalesced( assert len(scatter_dim) == len(inputs) for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): - assert ( - tensor.size(dim) % group_size == 0 - ), f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" + assert tensor.size(dim) % group_size == 0, ( + f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" + ) if dim != 0: 
tensor_list = torch.chunk(tensor, group_size, dim=dim) inputs[idx] = torch.cat(tensor_list) @@ -574,6 +574,7 @@ class AsyncCollectiveTensor(torch.Tensor): tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size) return _maybe_wrap_tensor(tensor) """ + elem: torch.Tensor completed: bool @@ -726,9 +727,9 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int group_size = len(rankset) tag = tag or c10d._get_group_tag(group) elif isinstance(group, DeviceMesh): - assert ( - group.ndim == 1 - ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + assert group.ndim == 1, ( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) # TODO: it should run collective in the whole mesh instead of dim 0 tag, rankset, _ = group._dim_group_infos[0] group_size = len(rankset) @@ -763,9 +764,9 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: elif isinstance(group, str): return group elif isinstance(group, DeviceMesh): - assert ( - group.ndim == 1 - ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + assert group.ndim == 1, ( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) return group._dim_group_infos[0][2] elif isinstance(group, tuple): if ( @@ -837,11 +838,13 @@ def allow_inflight_collective_as_graph_input_ctx(value: bool = True): req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True) return y + @torch.compile(fullgraph=True) def all_reduce_wait_compiled(y): torch.ops.c10d_functional.wait_tensor(y) return y * y + x = torch.ones(1280, 1280, device="cuda") + self.rank # the context manager ensures that `wait_tensor(y)` will wait on the correct work object with allow_inflight_collective_as_graph_input_ctx(): @@ -1057,9 +1060,9 @@ def all_gather_tensor_inplace( tag: str = "", gather_dim: int = 0, ): - assert ( - not async_op - ), "Can't remap async version of inplace op to functional collective" + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD assert group is not None @@ -1076,9 +1079,9 @@ def reduce_scatter_tensor_inplace( scatter_dim: int = 0, tag: str = "", ): - assert ( - not async_op - ), "Can't remap async version of inplace op to functional collective" + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD assert group is not None @@ -1105,9 +1108,9 @@ def all_reduce_inplace( async_op: bool = False, tag: str = "", ): - assert ( - not async_op - ), "Can't remap async version of inplace op to functional collective" + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD assert group is not None @@ -1124,9 +1127,9 @@ def all_to_all_inplace( async_op=False, tag: str = "", ): - assert ( - not async_op - ), "Can't remap async version of inplace op to functional collective" + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD assert group is not None @@ -1149,12 +1152,12 @@ def all_gather_inplace( async_op=False, tag: str = "", ): - assert ( - not async_op - ), "Can't remap async version of inplace op to functional collective" - assert all( - t.size(0) == tensor.size(0) for t in tensor_list - ), "Remapping variable size all_gather is not yet supported" + assert not async_op, ( + "Can't remap async version of 
inplace op to functional collective" + ) + assert all(t.size(0) == tensor.size(0) for t in tensor_list), ( + "Remapping variable size all_gather is not yet supported" + ) group = group or dist.group.WORLD assert group is not None diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index f816fb2af0d8..e146a3598561 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -592,7 +592,9 @@ class ShardedTensor(ShardedTensorBase): assert ( isinstance(device, torch.device) and device.index == torch.cuda.current_device() - ), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!""" + ), ( + """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!""" + ) current_device = torch.device(torch.cuda.current_device()) # returns a copy of ShardedTensor on CUDA current device @@ -831,7 +833,9 @@ class ShardedTensor(ShardedTensorBase): "rank:1/cuda:1", ], ) - >>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4]) + >>> st = ShardedTensor._init_from_local_tensor( + ... local_tensor, sharding_spec, [2, 4] + ... ) >>> st ShardedTensor( ShardedTensorMetadata( diff --git a/torch/distributed/_shard/sharded_tensor/reshard.py b/torch/distributed/_shard/sharded_tensor/reshard.py index 30505721ff91..daef9c358618 100644 --- a/torch/distributed/_shard/sharded_tensor/reshard.py +++ b/torch/distributed/_shard/sharded_tensor/reshard.py @@ -219,9 +219,7 @@ def reshard_local_shard( output_tensor_size = list(st_size) output_tensor_size[current_sharding_dim] = sharded_dim_size output_tensor_size[reshard_dim] = input_split_sizes[current_rank] - output_tensor_list[ - placement.rank() - ] = torch.empty( # type: ignore[union-attr, index] + output_tensor_list[placement.rank()] = torch.empty( # type: ignore[union-attr, index] output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype ) indices.append(placement.rank()) # type: ignore[union-attr, index, arg-type] diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py index 5e6f4d2a1a6e..24de2628c0ab 100644 --- a/torch/distributed/_sharded_tensor/__init__.py +++ b/torch/distributed/_sharded_tensor/__init__.py @@ -16,6 +16,6 @@ with warnings.catch_warnings(): stacklevel=2, ) -sys.modules[ - "torch.distributed._sharded_tensor" -] = torch.distributed._shard.sharded_tensor +sys.modules["torch.distributed._sharded_tensor"] = ( + torch.distributed._shard.sharded_tensor +) diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index b4c1b908c192..640922762386 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -67,7 +67,7 @@ def _all_gather_sharded_tensor( class CompanionMismatch(Exception): - ... 
+ pass def _iterate_state_dict( @@ -409,9 +409,9 @@ def _create_cpu_state_dict( def unpin_memory(t): succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr())) - assert ( - succ == 0 - ), f"Unpinning shared memory failed with error-code: {succ}" + assert succ == 0, ( + f"Unpinning shared memory failed with error-code: {succ}" + ) weakref.finalize(t, unpin_memory, t) succ = int( @@ -421,9 +421,9 @@ def _create_cpu_state_dict( 1, # lines up with 'cudaHostRegisterPortable' ) ) - assert ( - succ == 0 - ), f"Pinning shared memory failed with error-code: {succ}" + assert succ == 0, ( + f"Pinning shared memory failed with error-code: {succ}" + ) return t elif pin_memory: return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 62299e2fa5d3..6f8f0b68fb5c 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -1525,8 +1525,7 @@ if TYPE_CHECKING: @overload def empty( *size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None -) -> torch.Tensor: - ... +) -> torch.Tensor: ... @overload @@ -1535,8 +1534,7 @@ def empty( *, dtype: Optional[_dtype] = None, device: Optional[_device] = None, -) -> torch.Tensor: - ... +) -> torch.Tensor: ... def empty( # type: ignore[misc] diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index 40f9727015d7..c5559cc10fab 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -6,6 +6,7 @@ we keep the old import path starts with `_tensor` for backward compatibility. We will remove this folder once we resolve all the BC issues. """ + import sys from importlib import import_module diff --git a/torch/distributed/_tools/fsdp2_mem_tracker.py b/torch/distributed/_tools/fsdp2_mem_tracker.py index c7a67ebee3de..1c03034c52c4 100644 --- a/torch/distributed/_tools/fsdp2_mem_tracker.py +++ b/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -153,7 +153,7 @@ class FSDPMemTracker(MemTracker): loss.backward() optimizer.step() fmt.display_snapshot("peak") - fmt.display_modulewise_snapshots(depth = 3, units = "MB") + fmt.display_modulewise_snapshots(depth=3, units="MB") """ diff --git a/torch/distributed/_tools/mem_tracker.py b/torch/distributed/_tools/mem_tracker.py index b72987af6f75..1416fa992383 100644 --- a/torch/distributed/_tools/mem_tracker.py +++ b/torch/distributed/_tools/mem_tracker.py @@ -379,7 +379,7 @@ class MemTracker(TorchDispatchMode): optimizer.step() optimizer.zero_grad() mt.display_snapshot("peak") - mt.display_modulewise_snapshots(depth = 3, units = "MiB") + mt.display_modulewise_snapshots(depth=3, units="MiB") Known Limitations: - The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``. 
diff --git a/torch/distributed/_tools/mod_tracker.py b/torch/distributed/_tools/mod_tracker.py index 6c4aabbb6d17..45e2a4f95710 100644 --- a/torch/distributed/_tools/mod_tracker.py +++ b/torch/distributed/_tools/mod_tracker.py @@ -42,6 +42,7 @@ class ModTracker: def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias + torch.nn.functional.linear = my_linear mod(torch.rand(2, 2)) diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index 37ac5944d527..5dabb23b6347 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -255,9 +255,9 @@ class RuntimeEstimator(TorchDispatchMode): Tuple[Any, float]: A tuple containing the result of the function and the mean operation time in milliseconds. """ - assert isinstance( - cls.fake_mode, FakeTensorMode - ), "Initialize/Assign FakeTensorMode before using this function" + assert isinstance(cls.fake_mode, FakeTensorMode), ( + "Initialize/Assign FakeTensorMode before using this function" + ) mean_op_time = 0.0 if func._overloadpacket not in _VIEW_OPS: try: @@ -289,9 +289,9 @@ class RuntimeEstimator(TorchDispatchMode): Tuple[Any, float]: A tuple containing the result of the function and the mean operation time in milliseconds. """ - assert ( - torch.cuda.is_available() - ), "Roofline estimation needs to access CUDA capabilities to make estimations" + assert torch.cuda.is_available(), ( + "Roofline estimation needs to access CUDA capabilities to make estimations" + ) def get_num_bytes(t: torch.Tensor) -> int: """ @@ -324,9 +324,9 @@ class RuntimeEstimator(TorchDispatchMode): float: The estimated compute time in nanoseconds. """ if func_packet in flop_registry: - assert ( - len(out_dtypes) == 1 - ), f"Only support single out dtype got {out_dtypes} for {func_packet}" + assert len(out_dtypes) == 1, ( + f"Only support single out dtype got {out_dtypes} for {func_packet}" + ) dtype = out_dtypes.pop() # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s peak_gpu_flops = get_device_tflops(dtype) * 1e15 @@ -487,9 +487,9 @@ class RuntimeEstimator(TorchDispatchMode): def __enter__(self) -> Self: fake_mode = active_fake_mode() - assert isinstance( - fake_mode, FakeTensorMode - ), "No FakeTensorMode found, designed to used under FakeTensorMode" + assert isinstance(fake_mode, FakeTensorMode), ( + "No FakeTensorMode found, designed to used under FakeTensorMode" + ) RuntimeEstimator.fake_mode = fake_mode self.total_runtime = 0.0 self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0)) diff --git a/torch/distributed/_tools/sac_estimator.py b/torch/distributed/_tools/sac_estimator.py index c5cecddc1680..ac0d1cf04cf3 100644 --- a/torch/distributed/_tools/sac_estimator.py +++ b/torch/distributed/_tools/sac_estimator.py @@ -245,7 +245,7 @@ class SACEstimator(TorchDispatchMode): with FakeTensorMode(): module = ... inp = ... 
- with sac_estimator('operator-level-cost-model'): + with sac_estimator("operator-level-cost-model"): output = module(inp) sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True) """ @@ -442,9 +442,9 @@ class SACEstimator(TorchDispatchMode): out_storages_cpu.update(_get_untyped_storages(o)) # Check if there's more than 1 CUDA device - assert ( - len(cuda_devices) <= 1 - ), f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}" + assert len(cuda_devices) <= 1, ( + f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}" + ) # 2. Get the memory consumed by output nbytes_cuda = sum( @@ -484,9 +484,9 @@ class SACEstimator(TorchDispatchMode): if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): acm_stats.sac_metadata.append(acm) else: - assert ( - mod_fqn == "Global" - ), f"Module {mod_fqn} not found in AC Mod Stats" + assert mod_fqn == "Global", ( + f"Module {mod_fqn} not found in AC Mod Stats" + ) self._sac_metadata.append(acm) return out @@ -979,9 +979,9 @@ class SACEstimator(TorchDispatchMode): def __enter__(self) -> Self: # type: ignore[no-untyped-def] fake_mode = active_fake_mode() - assert isinstance( - fake_mode, FakeTensorMode - ), "SAC Estimator should be called in FakeTensorMode" + assert isinstance(fake_mode, FakeTensorMode), ( + "SAC Estimator should be called in FakeTensorMode" + ) RuntimeEstimator.fake_mode = fake_mode self._mod_tracker.register_user_hooks( pre_fw_hook=self._pre_fw_hook, diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index b7e60f2022b0..2a08212dfa9c 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -38,9 +38,9 @@ def _perform_local_step( """ overlap_info = zero._overlap_info bucket_index = bucket.index() - assert ( - len(zero.optim.param_groups) == 1 - ), "Overlapping DDP with ZeRO only supports a single parameter group" + assert len(zero.optim.param_groups) == 1, ( + "Overlapping DDP with ZeRO only supports a single parameter group" + ) # Construct the `gradients` input for the local optimizer step, which # expects `None` in a list position to indicate that the corresponding @@ -49,9 +49,9 @@ def _perform_local_step( gradients: list[Optional[torch.Tensor]] = [ _NO_PARAM_UPDATE for _ in range(num_local_optim_params) ] - assert ( - bucket_index in overlap_info.offsets - ), f"Bucket index {bucket_index} was not assigned to rank {rank}" + assert bucket_index in overlap_info.offsets, ( + f"Bucket index {bucket_index} was not assigned to rank {rank}" + ) gradients_offset = overlap_info.offsets[bucket_index] bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index] bucket_offset = bucket_assignment.offset @@ -77,13 +77,13 @@ def _broadcast_bucket( :class:`ZeroRedundancyOptimizer` instance. 
""" overlap_info = zero._overlap_info - assert ( - len(overlap_info.assigned_ranks_per_bucket) > bucket_index - ), "`assigned_ranks_per_bucket` is not fully constructed" + assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, ( + "`assigned_ranks_per_bucket` is not fully constructed" + ) # Sort to ensure the same ordering across ranks assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index]) assert len(assigned_ranks) > 0, ( - f"Bucket {bucket_index} should be " "assigned to at least one rank" + f"Bucket {bucket_index} should be assigned to at least one rank" ) for assigned_rank in assigned_ranks: bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank] @@ -273,9 +273,9 @@ def hook_with_zero_step( rank = zero.global_rank assert overlap_info.status == _OverlapStatus.INITIALIZED - assert ( - len(overlap_info.assigned_ranks_per_bucket) > bucket_index - ), "`assigned_ranks_per_bucket` is not fully constructed" + assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, ( + "`assigned_ranks_per_bucket` is not fully constructed" + ) assigned_to_bucket = ( rank in overlap_info.assigned_ranks_per_bucket[bucket_index] ) @@ -288,9 +288,9 @@ def hook_with_zero_step( # Check that buckets are indexed incrementally starting from 0 in the # order of their autograd hooks firing if len(overlap_info.bucket_indices_seen) > 0: - assert ( - overlap_info.bucket_indices_seen[-1] == bucket_index - 1 - ), "Bucket indices are not in incremental order" + assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, ( + "Bucket indices are not in incremental order" + ) else: assert bucket_index == 0, "Bucket indices do not start from 0" overlap_info.bucket_indices_seen.append(bucket_index) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index a8bcd8671a4e..a64b502255f6 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -129,7 +129,7 @@ def bf16_compress_hook( def fp16_compress_wrapper( - hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]] + hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: """ Cast input tensor to ``torch.float16``, cast result of hook back to input dtype. @@ -167,7 +167,7 @@ def fp16_compress_wrapper( def bf16_compress_wrapper( - hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]] + hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: """ Warning: This API is experimental, and it requires NCCL version later than 2.9.6. diff --git a/torch/distributed/algorithms/join.py b/torch/distributed/algorithms/join.py index 2bdbb5fff42b..70d74af7ead0 100644 --- a/torch/distributed/algorithms/join.py +++ b/torch/distributed/algorithms/join.py @@ -223,8 +223,7 @@ class Join: self._rank = dist.get_rank(self._process_group) self._device = device - def __enter__(self): - ... + def __enter__(self): ... 
def __exit__( self, diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index 0438043a6e74..407014418ecc 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -52,7 +52,10 @@ def average_parameters( def get_params_to_average( - params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]] + params: Union[ + Iterable[torch.nn.Parameter], + Iterable[dict[str, torch.nn.Parameter]], + ], ): """ Return a list of parameters that need to average. diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index 56af56743d2f..d7fa7ac66745 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -550,9 +550,7 @@ def create_default_global_save_plan( new_item = dataclasses.replace(item, index=new_index) new_items.append(new_item) - assert ( - item.tensor_data.chunk is not None - ), f""" + assert item.tensor_data.chunk is not None, f""" Cannot create MD for tensor without bounds. FQN: {item.index.fqn} """ diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 626b5cd48603..025ab0a178cb 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -414,41 +414,33 @@ class FileSystemBase(ABC): @abstractmethod def create_stream( self, path: Union[str, os.PathLike], mode: str - ) -> Generator[io.IOBase, None, None]: - ... + ) -> Generator[io.IOBase, None, None]: ... @abstractmethod def concat_path( self, path: Union[str, os.PathLike], suffix: str - ) -> Union[str, os.PathLike]: - ... + ) -> Union[str, os.PathLike]: ... @abstractmethod def rename( self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike] - ) -> None: - ... + ) -> None: ... @abstractmethod - def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: - ... + def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: ... @abstractmethod - def mkdir(self, path: Union[str, os.PathLike]) -> None: - ... + def mkdir(self, path: Union[str, os.PathLike]) -> None: ... @classmethod @abstractmethod - def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: - ... + def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: ... @abstractmethod - def exists(self, path: Union[str, os.PathLike]) -> bool: - ... + def exists(self, path: Union[str, os.PathLike]) -> bool: ... @abstractmethod - def rm_file(self, path: Union[str, os.PathLike]) -> None: - ... + def rm_file(self, path: Union[str, os.PathLike]) -> None: ... class FileSystem(FileSystemBase): @@ -512,7 +504,6 @@ class FileSystem(FileSystemBase): class _FileSystemWriter(StorageWriter): - """ Basic implementation of StorageWriter using file IO. 
@@ -800,9 +791,9 @@ class FileSystemReader(StorageReader): ) target_tensor = planner.resolve_tensor(req).detach() - assert ( - target_tensor.size() == tensor.size() - ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + assert target_tensor.size() == tensor.size(), ( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) target_tensor.copy_(tensor) planner.commit_tensor(req, target_tensor) diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py index add231f1441c..43193afe6e67 100644 --- a/torch/distributed/checkpoint/optimizer.py +++ b/torch/distributed/checkpoint/optimizer.py @@ -135,12 +135,12 @@ def _get_state_dict_2d_layout( for key, value in state_dict.items(): specs[key] = (None, value.size()) if _is_nested_tensor(value): - assert ( - len(value.local_shards()) == 1 - ), "Cannot handle ST with multiple shards" - assert isinstance( - value, ShardedTensor - ), "Can only handle nested ShardedTensor" + assert len(value.local_shards()) == 1, ( + "Cannot handle ST with multiple shards" + ) + assert isinstance(value, ShardedTensor), ( + "Can only handle nested ShardedTensor" + ) shard = value.local_shards()[0] specs[key] = ( shard.metadata.shard_offsets, diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py index e6513e53c110..bc0b26dfe4d0 100644 --- a/torch/distributed/checkpoint/planner.py +++ b/torch/distributed/checkpoint/planner.py @@ -151,7 +151,7 @@ class SavePlanner(abc.ABC): >>> storage_meta: Optional[StorageMeta], >>> is_coordinator: bool, >>> ) -> None: - >>> # prefix all keys with `foo_`` + >>> # prefix all keys with `foo_`` >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator) Modifying local plan and lookup in tandem. 
This is useful when fine control of how data is persisted @@ -175,8 +175,8 @@ class SavePlanner(abc.ABC): >>> from itertools import zip_longest >>> from dataclasses import replace >>> class DDPLoadBalancingPlanner(DefaultSavePlanner): - >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 - >>> # This sample doesn't handle ShardedTensors + >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0 + >>> # This sample doesn't handle ShardedTensors >>> def create_global_plan(self, all_plans): >>> iters = [iter(all_plans[0].items)] * len(all_plans) >>> items_per_rank = [ @@ -347,7 +347,7 @@ class LoadPlanner: >>> self.is_coordinator = is_coordinator >>> >>> def load_bytes(self, read_item, value): - >>> # Remove the "foo_" prefix + >>> # Remove the "foo_" prefix >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 5312f07d06ec..4473927f9d21 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -140,10 +140,12 @@ class StateDictOptions: @dataclass class _StateDictInfo(StateDictOptions): fqn_param_mapping: dict[ - Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] + Union[str, torch.Tensor], + Union[FQNS_T, torch.Tensor], ] = field(default_factory=dict) shared_params_mapping: dict[ - Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] + Union[str, torch.Tensor], + Union[FQNS_T, torch.Tensor], ] = field(default_factory=dict) submodule_prefixes: set[str] = field(default_factory=set) handle_model: bool = True @@ -1140,7 +1142,9 @@ def get_state_dict( >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim) - >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim) + >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict( + ... fsdp_model, fsdp_optim + ... ) >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >>> # the asserts will fail. diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index 5221f5ff8ed7..6e9f8b83e2b0 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -125,7 +125,9 @@ def load( >>> my_model = MyModule() >>> optimizer = Adagrad(my_model.parameters()) >>> model_state_dict = my_model.state_dict() - >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1") + >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader( + ... "/checkpoint/1" + ... ) >>> torch.distributed.checkpoint.load_state_dict( >>> state_dict=model_state_dict, diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index 4d3ae2d9b0f2..e36357f9a65c 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -127,7 +127,9 @@ def save( >>> state_dict = {"model": my_model} - >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") + >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( + ... "/checkpoint/1" + ... 
) >>> torch.distributed.checkpoint.save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, @@ -206,7 +208,9 @@ def async_save( >>> state_dict = {"model": my_model} - >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") + >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter( + ... "/checkpoint/1" + ... ) >>> checkpoint_future = torch.distributed.checkpoint.async_save( >>> state_dict=state_dict, >>> storage_writer=fs_storage_writer, @@ -223,7 +227,9 @@ def async_save( pg = process_group or _get_default_group() assert ( torch.device("cpu") in pg._device_types # type: ignore[attr-defined] - ), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" + ), ( + "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" + ) storage_writer = cast( StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index cd483f8a7792..d4e2533fb3d9 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -32,7 +32,7 @@ R = TypeVar("R") def _get_failure_dict( - results: list[Union[T, WRAPPED_EXCEPTION]] + results: list[Union[T, WRAPPED_EXCEPTION]], ) -> dict[int, WRAPPED_EXCEPTION]: return cast( dict[int, WRAPPED_EXCEPTION], diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index d2a79a4d1c5f..ba603220ce0c 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -221,8 +221,12 @@ else: if cur_rank in mesh_nd: res_flattened_mesh = flattened_mesh self.child_to_root_mapping[res_flattened_mesh] = root_mesh # type: ignore[possibly-undefined] - self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = res_flattened_mesh # type: ignore[possibly-undefined] - self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple(flatten_dims_in_root) # type: ignore[possibly-undefined] + self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = ( + res_flattened_mesh # type: ignore[possibly-undefined] + ) + self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple( + flatten_dims_in_root + ) # type: ignore[possibly-undefined] return res_flattened_mesh @@ -242,9 +246,9 @@ else: root_mesh = self.get_root_mesh(device_mesh) child_mesh_dim_names = device_mesh.mesh_dim_names if root_mesh and child_mesh_dim_names: - assert ( - len(child_mesh_dim_names) == 1 - ), "The submesh can only be a 1D mesh." + assert len(child_mesh_dim_names) == 1, ( + "The submesh can only be a 1D mesh." + ) child_mesh_dim_name = child_mesh_dim_names[0] return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name) return None @@ -763,7 +767,9 @@ else: root_mesh, None ) if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys(): - dim_group_infos = root_to_flatten_mapping[mesh_dim]._dim_group_infos[0][:2] # type: ignore[index] + dim_group_infos = root_to_flatten_mapping[ + mesh_dim # type: ignore[index] + ]._dim_group_infos[0][:2] return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos)) else: mesh_dim = ( @@ -905,9 +911,9 @@ else: mesh_dim = 0 mesh_dim_group = not_none(self.get_group(mesh_dim)) - assert isinstance( - mesh_dim_group, ProcessGroup - ), "We expect ProcessGroup before calling `get_rank`!" + assert isinstance(mesh_dim_group, ProcessGroup), ( + "We expect ProcessGroup before calling `get_rank`!" 
+ ) return not_none(get_rank(mesh_dim_group)) def get_coordinate(self) -> Optional[list[int]]: diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 26ec5fecb27c..b3557be6070d 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -334,12 +334,12 @@ class Backend(str): # noqa: SLOT000 # Allow UCC plugin if Pytorch is not built with native support. # TODO: remove this exception once UCC plugin is fully deprecated. if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()): - assert not hasattr( - Backend, name.upper() - ), f"{name.upper()} c10d backend already exist" - assert ( - name.upper() not in Backend._plugins - ), f"{name.upper()} c10d backend creator function already exist" + assert not hasattr(Backend, name.upper()), ( + f"{name.upper()} c10d backend already exist" + ) + assert name.upper() not in Backend._plugins, ( + f"{name.upper()} c10d backend creator function already exist" + ) setattr(Backend, name.upper(), name.lower()) Backend.backend_list.append(name.lower()) @@ -1650,9 +1650,9 @@ def init_process_group( if "torch._dynamo" in sys.modules: torch._dynamo.trace_rules.clear_lru_cache() - assert (store is None) or ( - init_method is None - ), "Cannot specify both init_method and store." + assert (store is None) or (init_method is None), ( + "Cannot specify both init_method and store." + ) if store is not None: assert world_size > 0, "world_size must be positive if using store" @@ -1734,7 +1734,10 @@ def init_process_group( ) _update_default_pg(default_pg) - _world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index] + _world.pg_group_ranks[GroupMember.WORLD] = { # type: ignore[index] + i: i + for i in range(GroupMember.WORLD.size()) # type: ignore[attr-defined] + } _backend = _world.pg_map[not_none(GroupMember.WORLD)][0] _default_pg_init_method = init_method @@ -1959,9 +1962,9 @@ def _new_process_group_helper( if not is_nccl_available(): raise RuntimeError("Distributed package doesn't have NCCL built in") if backend_options is not None: - assert isinstance( - backend_options, ProcessGroupNCCL.Options - ), "Expected backend_options argument to be of type ProcessGroupNCCL.Options" + assert isinstance(backend_options, ProcessGroupNCCL.Options), ( + "Expected backend_options argument to be of type ProcessGroupNCCL.Options" + ) if backend_options._timeout != timeout: warnings.warn( "backend_options._timeout was specified, " @@ -2001,9 +2004,9 @@ def _new_process_group_helper( ) backend_type = ProcessGroup.BackendType.XCCL else: - assert ( - backend_str.upper() in Backend._plugins - ), f"Unknown c10d backend type {backend_str.upper()}" + assert backend_str.upper() in Backend._plugins, ( + f"Unknown c10d backend type {backend_str.upper()}" + ) backend_plugin = Backend._plugins[backend_str.upper()] creator_fn = backend_plugin.creator_fn @@ -2630,8 +2633,10 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]: >>> # xdoctest: +SKIP("no rank") >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank >>> recv_tensor = torch.randn(2, dtype=torch.float32) - >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size) - >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size) + >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size) + >>> recv_op = dist.P2POp( + ... dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size + ... 
) >>> reqs = batch_isend_irecv([send_op, recv_op]) >>> for req in reqs: >>> req.wait() @@ -2758,7 +2763,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): >>> # xdoctest: +SKIP("no rank") >>> # All tensors below are of torch.int64 type. >>> # We have 2 process groups, 2 ranks. - >>> device = torch.device(f'cuda:{rank}') + >>> device = torch.device(f"cuda:{rank}") >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank >>> tensor tensor([1, 2], device='cuda:0') # Rank 0 @@ -2770,7 +2775,9 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): >>> # All tensors below are of torch.cfloat type. >>> # We have 2 process groups, 2 ranks. - >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j) + >>> tensor = torch.tensor( + ... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device + ... ) + 2 * rank * (1 + 1j) >>> tensor tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 @@ -3380,9 +3387,9 @@ def recv_object_list( ) rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src) - assert ( - rank_sizes == rank_objects - ), "Mismatch in return ranks for object sizes and objects." + assert rank_sizes == rank_objects, ( + "Mismatch in return ranks for object sizes and objects." + ) # Deserialize objects using their stored sizes. offset = 0 for i, obj_size in enumerate(object_sizes_tensor): @@ -3673,8 +3680,10 @@ def all_gather(tensor_list, tensor, group=None, async_op=False): >>> # xdoctest: +SKIP("need process group init") >>> # All tensors below are of torch.int64 dtype. >>> # We have 2 process groups, 2 ranks. - >>> device = torch.device(f'cuda:{rank}') - >>> tensor_list = [torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)] + >>> device = torch.device(f"cuda:{rank}") + >>> tensor_list = [ + ... torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2) + ... ] >>> tensor_list [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0 [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1 @@ -3689,11 +3698,15 @@ def all_gather(tensor_list, tensor, group=None, async_op=False): >>> # All tensors below are of torch.cfloat dtype. >>> # We have 2 process groups, 2 ranks. - >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)] + >>> tensor_list = [ + ... torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2) + ... ] >>> tensor_list [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0 [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1 - >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j) + >>> tensor = torch.tensor( + ... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device + ... ) + 2 * rank * (1 + 1j) >>> tensor tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 @@ -3769,7 +3782,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal >>> # xdoctest: +SKIP("need process group init") >>> # All tensors below are of torch.int64 dtype and on CUDA devices. >>> # We have two ranks. 
- >>> device = torch.device(f'cuda:{rank}') + >>> device = torch.device(f"cuda:{rank}") >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank >>> tensor_in tensor([1, 2], device='cuda:0') # Rank 0 @@ -3969,8 +3982,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): ) elif gather_list: raise ValueError( - "Argument ``gather_list`` must NOT be specified " - "on non-destination ranks." + "Argument ``gather_list`` must NOT be specified on non-destination ranks." ) @@ -4141,8 +4153,7 @@ def scatter( else: if scatter_list: raise ValueError( - "Argument ``scatter_list`` must NOT be specified " - "on non-source ranks." + "Argument ``scatter_list`` must NOT be specified on non-source ranks." ) input_tensors = [] output_tensors = [tensor] @@ -4225,7 +4236,7 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F >>> # xdoctest: +SKIP("need process group init") >>> # All tensors below are of torch.int64 dtype and on CUDA devices. >>> # We have two ranks. - >>> device = torch.device(f'cuda:{rank}') + >>> device = torch.device(f"cuda:{rank}") >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device) >>> # Input in concatenation form >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device) @@ -4381,7 +4392,7 @@ def all_to_all_single( >>> # Essentially, it is similar to following operation: >>> scatter_list = list(input.chunk(world_size)) - >>> gather_list = list(output.chunk(world_size)) + >>> gather_list = list(output.chunk(world_size)) >>> for i in range(world_size): >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) @@ -4411,7 +4422,9 @@ def all_to_all_single( >>> # Another example with tensors of torch.cfloat type. - >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j) + >>> input = torch.tensor( + ... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat + ... ) + 4 * rank * (1 + 1j) >>> input tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0 tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1 @@ -4510,7 +4523,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False >>> # Essentially, it is similar to following operation: >>> scatter_list = input - >>> gather_list = output + >>> gather_list = output >>> for i in range(world_size): >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i) @@ -4544,7 +4557,9 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3 >>> # Another example with tensors of torch.cfloat type. - >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j) + >>> input = torch.tensor( + ... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat + ... 
) + 4 * rank * (1 + 1j) >>> input = list(input.chunk(4)) >>> input [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0 @@ -4882,9 +4897,9 @@ def split_group( backend_config = BackendConfig(backend) if pg_options is not None: - assert isinstance( - pg_options, ProcessGroupNCCL.Options - ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" + assert isinstance(pg_options, ProcessGroupNCCL.Options), ( + "Expected pg_options argument to be of type ProcessGroupNCCL.Options" + ) else: # default pg_options same as the parent process group pg_options = parent_backend.options @@ -5086,9 +5101,9 @@ def _new_group_with_tag( if device_id is None: device_id = default_pg.bound_device_id elif default_pg.bound_device_id is not None: - assert ( - device_id == default_pg.bound_device_id - ), "Mismatched bound device between new pg and the default pg." + assert device_id == default_pg.bound_device_id, ( + "Mismatched bound device between new pg and the default pg." + ) default_backend, default_store = _world.pg_map[default_pg] global_rank = default_pg.rank() global_world_size = default_pg.size() @@ -5408,9 +5423,9 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro def _find_or_create_pg_by_ranks_and_tag( tag: str, ranks: list[int], stride: int ) -> ProcessGroup: - assert ( - len(ranks) % stride == 0 - ), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + assert len(ranks) % stride == 0, ( + f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + ) my_rank = get_rank() my_ranks = None diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py index bdd5a3584976..8e47868e2977 100644 --- a/torch/distributed/elastic/control_plane.py +++ b/torch/distributed/elastic/control_plane.py @@ -40,8 +40,9 @@ def worker_main() -> Generator[None, None, None]: def main(): pass - if __name__=="__main__": - main() + + if __name__ == "__main__": + main() """ with ExitStack() as stack: diff --git a/torch/distributed/elastic/events/__init__.py b/torch/distributed/elastic/events/__init__.py index e6c2a271644f..02e158b021a0 100644 --- a/torch/distributed/elastic/events/__init__.py +++ b/torch/distributed/elastic/events/__init__.py @@ -14,7 +14,10 @@ Example of usage: :: from torch.distributed.elastic import events - event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...}) + + event = events.Event( + name="test_event", source=events.EventSource.WORKER, metadata={...} + ) events.get_logging_handler(destination="console").info(event) """ diff --git a/torch/distributed/elastic/metrics/__init__.py b/torch/distributed/elastic/metrics/__init__.py index 4b72dcd7c602..b07671fbac9d 100644 --- a/torch/distributed/elastic/metrics/__init__.py +++ b/torch/distributed/elastic/metrics/__init__.py @@ -52,11 +52,12 @@ The example below measures the latency for the ``calculate()`` function. 
metrics.configure(metrics.NullMetricsHandler()) metrics.configure(metrics.ConsoleMetricsHandler(), "my_module") + def my_method(): - start = time.time() - calculate() - end = time.time() - metrics.put_metric("calculate_latency", int(end-start), "my_module") + start = time.time() + calculate() + end = time.time() + metrics.put_metric("calculate_latency", int(end - start), "my_module") You may also use the torch.distributed.elastic.metrics.prof` decorator to conveniently and succinctly profile functions @@ -70,15 +71,16 @@ to conveniently and succinctly profile functions metrics.configure(metrics.ConsoleMetricsHandler(), "foobar") metrics.configure(metrics.ConsoleMetricsHandler(), "Bar") + @metrics.prof def foo(): - pass + pass - class Bar(): - @metrics.prof - def baz(): - pass + class Bar: + @metrics.prof + def baz(): + pass ``@metrics.prof`` will publish the following metrics :: @@ -102,8 +104,8 @@ console. import torch.distributed.elastic.metrics as metrics - metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic") - metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app") + metrics.configure(metrics.ConsoleMetricHandler(), group="torchelastic") + metrics.configure(metrics.ConsoleMetricHandler(), group="my_app") **Writing a Custom Metric Handler**: @@ -117,13 +119,15 @@ Below is a toy example that prints the metrics to ``stdout`` import torch.distributed.elastic.metrics as metrics + class StdoutMetricHandler(metrics.MetricHandler): - def emit(self, metric_data): - ts = metric_data.timestamp - group = metric_data.group_name - name = metric_data.name - value = metric_data.value - print(f"[{ts}][{group}]: {name}={value}") + def emit(self, metric_data): + ts = metric_data.timestamp + group = metric_data.group_name + name = metric_data.name + value = metric_data.value + print(f"[{ts}][{group}]: {name}={value}") + metrics.configure(StdoutMetricHandler(), group="my_app") diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 551d18aed733..2f4100a461ad 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -123,6 +123,7 @@ def prof(fn=None, group: str = "torchelastic"): def x(): pass + @metrics.prof(group="agent") def y(): pass diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index d0d311d2fb49..fe829a26ce84 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -20,22 +20,23 @@ Usage 1: Launching two trainers as a function from torch.distributed.elastic.multiprocessing import Std, start_processes + def trainer(a, b, c): - pass # train + pass # train # runs two trainers # LOCAL_RANK=0 trainer(1,2,3) # LOCAL_RANK=1 trainer(4,5,6) ctx = start_processes( - name="trainer", - entrypoint=trainer, - args={0: (1,2,3), 1: (4,5,6)}, - envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}}, - log_dir="/tmp/foobar", - redirects=Std.ALL, # write all worker stdout/stderr to a log file - tee={0: Std.ERR}, # tee only local rank 0's stderr to console - ) + name="trainer", + entrypoint=trainer, + args={0: (1, 2, 3), 1: (4, 5, 6)}, + envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}}, + log_dir="/tmp/foobar", + redirects=Std.ALL, # write all worker stdout/stderr to a log file + tee={0: Std.ERR}, # tee only local rank 0's stderr to console + ) # waits for all copies of trainer to finish ctx.wait() diff --git 
a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index ca51d0329fe9..6d899a95d6a7 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -165,9 +165,11 @@ def to_map( Example: :: - to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} - to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT} - to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} + to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} + to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT} + to_map( + {0: Std.OUT, 1: Std.OUT}, local_world_size=2 + ) # returns: {0: Std.OUT, 1: Std.OUT} """ if isinstance(val_or_map, Std): return dict.fromkeys(range(local_world_size), val_or_map) @@ -304,7 +306,9 @@ class DefaultLogsSpecs(LogsSpecs): if not self._run_log_dir: self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id) - attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}") # type: ignore[call-overload] + attempt_log_dir = os.path.join( + self._run_log_dir, f"attempt_{restart_count}" + ) # type: ignore[call-overload] shutil.rmtree(attempt_log_dir, ignore_errors=True) os.makedirs(attempt_log_dir) @@ -868,9 +872,7 @@ class SubprocessContext(PContext): if result.is_failed(): first_failure = min(result.failures.values(), key=lambda f: f.timestamp) logger.error( - "failed (exitcode: %s)" - " local_rank: %s (pid: %s)" - " of binary: %s", + "failed (exitcode: %s) local_rank: %s (pid: %s) of binary: %s", first_failure.exitcode, first_failure.local_rank, first_failure.pid, diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index 34b22bbd8a2e..57e445a3d02a 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -318,14 +318,14 @@ def record( error_handler = get_error_handler() error_handler.initialize() try: - foobar() + foobar() except ChildFailedError as e: - _, failure = e.get_first_failure() - error_handler.dump_error_file(failure.error_file, failure.exitcode) - raise + _, failure = e.get_first_failure() + error_handler.dump_error_file(failure.error_file, failure.exitcode) + raise except Exception as e: - error_handler.record_exception(e) - raise + error_handler.record_exception(e) + raise .. important:: use this decorator once per process at the top level method, typically this is the main method. 
@@ -338,8 +338,9 @@ def record( def main(): pass - if __name__=="__main__": - main() + + if __name__ == "__main__": + main() """ if not error_handler: diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py index 22ec0c9a0f67..0766df8e5f3a 100644 --- a/torch/distributed/elastic/rendezvous/__init__.py +++ b/torch/distributed/elastic/rendezvous/__init__.py @@ -120,11 +120,7 @@ of the following implementations that come with PyTorch: backend = C10dRendezvousBackend(store, "my_run_id") rdzv_handler = DynamicRendezvousHandler.from_backend( - run_id="my_run_id", - store=store, - backend=backend, - min_nodes=2, - max_nodes=4 + run_id="my_run_id", store=store, backend=backend, min_nodes=2, max_nodes=4 ) """ diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py index 58b978f5bb70..be0d6e28536f 100644 --- a/torch/distributed/elastic/rendezvous/api.py +++ b/torch/distributed/elastic/rendezvous/api.py @@ -89,8 +89,14 @@ class RendezvousStoreInfo: addr = local_addr or socket.getfqdn() # When TCPStore is not shared, we fallback to get_free_port. port = server_port or get_free_port() - store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type] - store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type] + store.set( + RendezvousStoreInfo.MASTER_ADDR_KEY, + addr.encode(encoding="UTF-8"), # type: ignore[arg-type] + ) + store.set( + RendezvousStoreInfo.MASTER_PORT_KEY, + str(port).encode(encoding="UTF-8"), # type: ignore[arg-type] + ) addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8") port = int( diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index ee67df0e56bb..6b049423ffc6 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -413,9 +413,9 @@ class EtcdRendezvous: active_version = self.wait_for_peers(expected_version) state = json.loads(active_version.value) - assert ( - state["version"] == expected_version - ), "Logic error: failed to observe version mismatch" + assert state["version"] == expected_version, ( + "Logic error: failed to observe version mismatch" + ) return self.confirm_phase(expected_version, this_rank) @@ -533,9 +533,9 @@ class EtcdRendezvous: "Rendezvous version changed. Must try join the new one." 
) - assert ( - len(state["participants"]) < self._num_max_workers - ), "Logic error: joinable rendezvous should always have space left" + assert len(state["participants"]) < self._num_max_workers, ( + "Logic error: joinable rendezvous should always have space left" + ) this_rank = len(state["participants"]) state["participants"].append(this_rank) diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py index d038ab95eaba..75f0d16f7d19 100644 --- a/torch/distributed/elastic/rendezvous/registry.py +++ b/torch/distributed/elastic/rendezvous/registry.py @@ -86,11 +86,15 @@ def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: from torch.distributed.elastic.rendezvous import rendezvous_handler_registry from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler + def create_my_rdzv(params: RendezvousParameters): - return MyCustomRdzv(params) + return MyCustomRdzv(params) + rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv) - my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters) + my_rdzv_handler = get_rendezvous_handler( + "my_rdzv_backend_name", RendezvousParameters + ) """ return handler_registry.create_handler(params) diff --git a/torch/distributed/elastic/utils/store.py b/torch/distributed/elastic/utils/store.py index 9fa97d5f960c..0afe82c46d89 100644 --- a/torch/distributed/elastic/utils/store.py +++ b/torch/distributed/elastic/utils/store.py @@ -57,10 +57,10 @@ def get_all(store, rank: int, prefix: str, world_size: int): :: - values = get_all(store, 'torchelastic/data', 3) - value1 = values[0] # retrieves the data for key torchelastic/data0 - value2 = values[1] # retrieves the data for key torchelastic/data1 - value3 = values[2] # retrieves the data for key torchelastic/data2 + values = get_all(store, "torchelastic/data", 3) + value1 = values[0] # retrieves the data for key torchelastic/data0 + value2 = values[1] # retrieves the data for key torchelastic/data1 + value3 = values[2] # retrieves the data for key torchelastic/data2 """ data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index 63ace36da62b..c8729fb63600 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -2,6 +2,7 @@ """ This file includes private common utilities for FSDP. """ + import logging import traceback import warnings @@ -200,9 +201,9 @@ def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamH # handles, meaning no entry in `_fully_sharded_module_to_handles` if state._handle is None: return None - assert ( - module in state._fully_sharded_module_to_handle - ), f"Expects a fully sharded module but got {module} on rank {state.rank}" + assert module in state._fully_sharded_module_to_handle, ( + f"Expects a fully sharded module but got {module} on rank {state.rank}" + ) return state._fully_sharded_module_to_handle[module] else: # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance. @@ -255,9 +256,9 @@ def _named_parameters_with_duplicates( This API is required as some modules overwrite `named_parameters()` but do not support `remove_duplicate`. """ - assert ( - "remove_duplicate" not in kwargs - ), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument." 
+ assert "remove_duplicate" not in kwargs, ( + "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument." + ) kwargs["remove_duplicate"] = False try: ret = list(module.named_parameters(**kwargs)) diff --git a/torch/distributed/fsdp/_exec_order_utils.py b/torch/distributed/fsdp/_exec_order_utils.py index b19e919de296..519ce39b1678 100644 --- a/torch/distributed/fsdp/_exec_order_utils.py +++ b/torch/distributed/fsdp/_exec_order_utils.py @@ -190,9 +190,9 @@ class _ExecOrderData: return if self.is_first_iter: msg_prefix = "Forward order differs across ranks:" - optional_local_indices: tuple[ - Optional[int], ... - ] = self._get_handle_indices(handle) + optional_local_indices: tuple[Optional[int], ...] = ( + self._get_handle_indices(handle) + ) device = handle.device # guaranteed to be non-CPU num_valid_indices = sum( (index is not None) for index in optional_local_indices @@ -250,8 +250,7 @@ class _ExecOrderData: ( rank, world_indices[ - rank - * num_valid_indices : (rank + 1) + rank * num_valid_indices : (rank + 1) * num_valid_indices ], ) diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 85d85c3f2e0e..1f5f0b217600 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -586,7 +586,10 @@ class FlatParamHandle: ) self._fsdp_extension = fsdp_extension self._init_flat_param_and_metadata( - params, fully_sharded_module, self._aligned_numel, use_orig_params # type: ignore[arg-type] + params, + fully_sharded_module, + self._aligned_numel, + use_orig_params, # type: ignore[arg-type] ) self._use_unsharded_views(as_params=False) @@ -978,9 +981,9 @@ class FlatParamHandle: shard_param_infos = self._get_shard_metadata( unsharded_start_idx, unsharded_end_idx ) - assert ( - len(shard_param_infos) == flat_param._num_params - ), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}" + assert len(shard_param_infos) == flat_param._num_params, ( + f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}" + ) flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined] flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined] @@ -996,9 +999,9 @@ class FlatParamHandle: unsharded flat parameter specifying the shard. 
""" flat_param_offsets = self._get_flat_param_offsets() - assert len(flat_param_offsets) == len( - self.flat_param._numels_with_padding - ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}" + assert len(flat_param_offsets) == len(self.flat_param._numels_with_padding), ( + f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}" + ) shard_param_infos: list[_ShardParamInfo] = [] sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1 # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices @@ -1075,9 +1078,9 @@ class FlatParamHandle: else: chunk = chunks[rank] numel_to_pad = chunks[0].numel() - chunk.numel() - assert ( - numel_to_pad >= 0 - ), "Chunk's size should be at most the first chunk's size" + assert numel_to_pad >= 0, ( + "Chunk's size should be at most the first chunk's size" + ) return chunk, numel_to_pad @staticmethod @@ -1302,7 +1305,8 @@ class FlatParamHandle: self._check_low_precision_shard() flat_param = self.flat_param _alloc_storage( - flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined] + flat_param._mp_shard, + flat_param._local_shard.size(), # type: ignore[attr-defined] ) # `copy_()` implicitly casts to the low precision flat_param._mp_shard.copy_( # type: ignore[attr-defined] @@ -1498,7 +1502,8 @@ class FlatParamHandle: # default stream suffices since the default stream waits for the # unshard stream. _no_dispatch_record_stream( - self.flat_param._mp_shard, self._device_handle.current_stream() # type: ignore[attr-defined] + self.flat_param._mp_shard, + self._device_handle.current_stream(), # type: ignore[attr-defined] ) _free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined] @@ -1593,8 +1598,7 @@ class FlatParamHandle: f"but got {flat_param.grad.device}", ) prev_iter_synced_gradients = ( - flat_param.grad.size() - == flat_param._local_shard.size() # type: ignore[attr-defined] + flat_param.grad.size() == flat_param._local_shard.size() # type: ignore[attr-defined] ) if prev_iter_synced_gradients: # TODO (awgu): Gradient accumulation outside `no_sync()` @@ -1668,8 +1672,7 @@ class FlatParamHandle: cast_grad_to_param_dtype_if_needed(flat_param) else: _p_assert( - not self.uses_sharded_strategy - or not flat_param._post_backward_called, # type: ignore[attr-defined] + not self.uses_sharded_strategy or not flat_param._post_backward_called, # type: ignore[attr-defined] "All sharded parameters that received a gradient in the " "post-backward should use `_saved_grad_shard`", ) @@ -2504,7 +2507,8 @@ class FlatParamHandle: """Return the FQNs of the parameters present in this rank's shard.""" fqns_in_shard: list[str] = [] for fqn, shard_param_info in zip( - self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined] + self.flat_param._fqns, + self.flat_param._shard_param_infos, # type: ignore[attr-defined] ): if shard_param_info.in_shard: fqns_in_shard.append(fqn) @@ -2694,7 +2698,7 @@ def _safe_setattr_tensor_or_param( def _convert_to_params( - tensors: list[Union[torch.Tensor, nn.Parameter]] + tensors: list[Union[torch.Tensor, nn.Parameter]], ) -> list[nn.Parameter]: return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors] diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index 676c3a1c31d3..e9a027393057 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ 
b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -374,9 +374,9 @@ def foreach_reduce( for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)): if (shard_dim := fsdp_param.fsdp_placement.dim) == 0: continue - assert ( - unsharded_grad.size(shard_dim) % world_size == 0 - ), f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}" + assert unsharded_grad.size(shard_dim) % world_size == 0, ( + f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}" + ) chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim) unsharded_grads[i] = torch.cat(chunks, dim=0) padded_unsharded_sizes = tuple( diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py index 1f6c3784968c..fdcf32e22a33 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -26,9 +26,9 @@ if torch._running_with_deploy(): else: def detect_compiled_autograd(): - assert ( - not torch.compiler.is_compiling() - ), "`detect_compiled_autograd()` is designed to be called in eager mode" + assert not torch.compiler.is_compiling(), ( + "`detect_compiled_autograd()` is designed to be called in eager mode" + ) global _compiled_autograd_enabled import torch._dynamo.compiled_autograd as ca diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index e92b59171e24..4bb882cb21c4 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -304,9 +304,9 @@ class FSDPParam: f"FSDP only supports 1D TP, not {self._tp_spec.placements}" ) split_factor = self._tp_spec.num_shards_map[shard_dim] - assert ( - 2 <= self._spmd_mesh.ndim <= 3 - ), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}." + assert 2 <= self._spmd_mesh.ndim <= 3, ( + f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}." + ) self._spmd_placements: tuple[Placement, ...] 
dp_shard_tp_placement = ( ( @@ -520,8 +520,9 @@ class FSDPParam: unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) if hasattr(self, "_unsharded_param"): assert compiled_autograd_enabled() - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - self._unsharded_param + with ( + torch.no_grad(), + torch.autograd._unsafe_preserve_version_counter(self._unsharded_param), ): # NOTE: Under compile, if an unsharded param goes through # resize_(full) -> copy_ -> resize_(0) pattern, we will remove those @@ -785,9 +786,9 @@ class FSDPParam: assert isinstance(grad, DTensor), f"{type(grad)}" placements = self._tp_spec.placements if placements != grad.placements: - assert len(self._tp_spec.placements) == len( - grad.placements - ), f"{self._tp_spec=} {grad.placements=}" + assert len(self._tp_spec.placements) == len(grad.placements), ( + f"{self._tp_spec=} {grad.placements=}" + ) grad = grad.redistribute(placements=placements) grad = grad._local_tensor return grad @@ -846,9 +847,9 @@ class FSDPParam: shard_dim = self.fsdp_placement.dim length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0 if local_tensor.size() != padded_sharded_size: - assert ( - shard_dim == 0 - ), f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}" + assert shard_dim == 0, ( + f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}" + ) padded_local_tensor = local_tensor.new_zeros(padded_sharded_size) padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_( local_tensor diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index cb9d586b7693..e149005ffc2c 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -424,9 +424,9 @@ class FSDPParamGroup: if all_reduce_pg is None and self._all_reduce_hook_stream is not None: # this means the native HSDP is not enabled, # but user may want to have a custom HSDP setup - assert ( - self._all_reduce_hook is not None - ), "all reduce hook stream is specified but hook itself is missing." + assert self._all_reduce_hook is not None, ( + "all reduce hook stream is specified but hook itself is missing." 
+ ) all_reduce_stream = self._all_reduce_hook_stream else: all_reduce_stream = self.comm_ctx.all_reduce_stream @@ -513,9 +513,10 @@ class FSDPParamGroup: else: raise ValueError(f"Unknown pass type: {pass_type}") target_fqn = target_fsdp_param_group._module_fqn - with record_function( - f"FSDP::{pass_type}_prefetch for {target_fqn}" - ), target_fsdp_param_group.use_training_state(training_state): + with ( + record_function(f"FSDP::{pass_type}_prefetch for {target_fqn}"), + target_fsdp_param_group.use_training_state(training_state), + ): async_op = target_fsdp_param_group.unshard_async_op target_fsdp_param_group.unshard(async_op) @@ -592,9 +593,9 @@ class FSDPParamGroup: def _register_state_dict_hooks(self) -> None: num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle) num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle) - assert ( - num_pre_save_hooks == num_pre_load_hooks - ), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}" + assert num_pre_save_hooks == num_pre_load_hooks, ( + f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}" + ) if num_pre_save_hooks > 0: return # already registered modules_with_fsdp_params: set[nn.Module] = { @@ -605,12 +606,12 @@ class FSDPParamGroup: self._to_sharded() for module in modules_with_fsdp_params: - self._module_to_pre_save_state_dict_hook_handle[ - module - ] = module.register_state_dict_pre_hook(to_sharded_hook) - self._module_to_pre_load_state_dict_hook_handle[ - module - ] = module._register_load_state_dict_pre_hook(to_sharded_hook) + self._module_to_pre_save_state_dict_hook_handle[module] = ( + module.register_state_dict_pre_hook(to_sharded_hook) + ) + self._module_to_pre_load_state_dict_hook_handle[module] = ( + module._register_load_state_dict_pre_hook(to_sharded_hook) + ) # Properties # @property diff --git a/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/torch/distributed/fsdp/_fully_shard/_fully_shard.py index 4d491cc79b5d..a5e94dda909d 100644 --- a/torch/distributed/fsdp/_fully_shard/_fully_shard.py +++ b/torch/distributed/fsdp/_fully_shard/_fully_shard.py @@ -60,8 +60,7 @@ def fully_shard( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: Optional[set[nn.Parameter]] = ..., -) -> FSDPModule: - ... +) -> FSDPModule: ... @overload @@ -74,8 +73,7 @@ def fully_shard( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: Optional[set[nn.Parameter]] = ..., -) -> list[FSDPModule]: - ... +) -> list[FSDPModule]: ... # The decorator adds a state object to `module` that can be accessed via diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 952daa0be2de..feaf8b882963 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -243,9 +243,9 @@ def _init_inter_node_process_group( if local_rank == my_local_rank: inter_node_pg = grp - assert ( - inter_node_pg is not None - ), f"{my_local_rank} expected to assign inter-node pg, but did not" + assert inter_node_pg is not None, ( + f"{my_local_rank} expected to assign inter-node pg, but did not" + ) return inter_node_pg diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index c6b9c3d1141b..ea1af6af0b23 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -145,9 +145,9 @@ def _unflatten_optim_state( dict will need to map these entries using the proper unflattened parameter IDs. 
""" - assert ( - not shard_state or to_save - ), "If ``shard_state`` is True, ``to_save`` has to be True." + assert not shard_state or to_save, ( + "If ``shard_state`` is True, ``to_save`` has to be True." + ) consolidated_state = _communicate_optim_state( fsdp_param_info, flat_param_state, @@ -218,9 +218,9 @@ def _communicate_optim_state( ): tensor_state[state_name] = value continue - assert ( - fsdp_state.compute_device is not None - ), "compute_device has not been initialized" + assert fsdp_state.compute_device is not None, ( + "compute_device has not been initialized" + ) if value.device.type != fsdp_state.compute_device.type: value = value.to(fsdp_state.compute_device) # Assume that positive-dimension tensor optimizer state @@ -394,7 +394,10 @@ def _shard_orig_param_state( and value.dim() > 0 and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD ): - value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator] + value = value.flatten()[ + intra_param_start_idx : intra_param_end_idx # type: ignore[operator] + + 1 + ].clone() new_optim_state[state_name] = value return new_optim_state @@ -489,9 +492,9 @@ def _flatten_optim_state_dict( if flat_state: flat_osd_state[key] = flat_state elif use_orig_params: - assert ( - len(fqns) == 1 - ), f"use_orig_params is True but there are multiple FQNs, {fqns}." + assert len(fqns) == 1, ( + f"use_orig_params is True but there are multiple FQNs, {fqns}." + ) if optim is not None: # NamedOptimizer or KeyedOptimizer case. state = optim.state.get(param, None) # type: ignore[call-overload] if state is not None: @@ -570,14 +573,13 @@ def _flatten_optim_state( flat_param = handle.flat_param num_unflat_params = len(unflat_param_names) assert num_unflat_params > 0, ( - "Expects at least one unflattened parameter corresponding to the " - "flat parameter" + "Expects at least one unflattened parameter corresponding to the flat parameter" ) unflat_param_shapes = flat_param._shapes num_unflat_param_shapes = len(unflat_param_shapes) - assert ( - num_unflat_params == num_unflat_param_shapes - ), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}" + assert num_unflat_params == num_unflat_param_shapes, ( + f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}" + ) # Check if these unflattened parameters have any optimizer state has_state = [ @@ -759,8 +761,7 @@ def _flatten_tensor_optim_state( flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel) flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined] assert flat_tensor.shape == flat_param_shape, ( - f"tensor optim state: {flat_tensor.shape} " - f"flat parameter: {flat_param_shape}" + f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}" ) return flat_tensor @@ -1065,9 +1066,9 @@ def _get_param_key_to_param( """ clean_fqn_to_curr_fqn: dict[str, str] = {} if is_named_optimizer: - assert ( - param_to_fqns is not None and flat_param_to_fqn is not None - ), "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None." + assert param_to_fqns is not None and flat_param_to_fqn is not None, ( + "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None." 
+ ) assert model is not None for key, _ in _named_parameters_with_duplicates(model): clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key @@ -1150,9 +1151,9 @@ def _check_missing_keys_on_rank( continue param_key = optim_state_key_to_param_key[r0_optim_state_key] if isinstance(param_key, int): - assert param_key >= 0 and param_key < len( - param_key_to_param - ), "Check the `param_key_to_param` construction" + assert param_key >= 0 and param_key < len(param_key_to_param), ( + "Check the `param_key_to_param` construction" + ) # We cannot use FSDPState.compute_device as this API is a global view. device = _get_pg_default_device(group) num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device) diff --git a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py index 70f80582d7f3..037bef9be3b3 100644 --- a/torch/distributed/fsdp/_shard_utils.py +++ b/torch/distributed/fsdp/_shard_utils.py @@ -121,9 +121,9 @@ def _all_gather_dtensor( """ All gather a DTensor in its sharded dimension and return the local tensor. """ - assert ( - root_mesh == tensor.device_mesh - ), "The device mesh of a tensor should be a root mesh." + assert root_mesh == tensor.device_mesh, ( + "The device mesh of a tensor should be a root mesh." + ) placements = list(copy.deepcopy(tensor.placements)) # FSDP placements: [Shard(0)] -> [Replicate()] diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 9d3c3b35259b..0d3285255df2 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -466,9 +466,9 @@ def _local_pre_load_state_dict_hook( ) return load_tensor = state_dict[fqn] - assert isinstance( - load_tensor, ShardedTensor - ), "Tensors in local_state_dict should be ShardedTensor." + assert isinstance(load_tensor, ShardedTensor), ( + "Tensors in local_state_dict should be ShardedTensor." + ) # Convert the ShardedTensor to a Tensor. 
flat_param = _module_handle(fsdp_state, module).flat_param diff --git a/torch/distributed/fsdp/_trace_utils.py b/torch/distributed/fsdp/_trace_utils.py index fcd09f6ce9f3..22cde2abc966 100644 --- a/torch/distributed/fsdp/_trace_utils.py +++ b/torch/distributed/fsdp/_trace_utils.py @@ -143,9 +143,9 @@ class _ExecOrderTracer: named_params = list(module.named_parameters()) curr_module = exec_info.curr_module if named_params: - assert ( - curr_module in exec_info.module_to_param_usage_infos - ), "The current module should have already been processed by a patched `call_module`" + assert curr_module in exec_info.module_to_param_usage_infos, ( + "The current module should have already been processed by a patched `call_module`" + ) exec_info.module_to_param_usage_infos[exec_info.curr_module].append( _ParamUsageInfo(module, named_params) ) diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py index ad495c73426d..1876c4a44431 100644 --- a/torch/distributed/fsdp/_unshard_param_utils.py +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -185,9 +185,9 @@ def _unshard_fsdp_state_params( yield return - assert ( - handle._training_state == HandleTrainingState.IDLE - ), f"Expects the handle training to be IDLE but got {handle._training_state}" + assert handle._training_state == HandleTrainingState.IDLE, ( + f"Expects the handle training to be IDLE but got {handle._training_state}" + ) handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py index 7282fbcd7b59..17ed0483f1c2 100644 --- a/torch/distributed/fsdp/api.py +++ b/torch/distributed/fsdp/api.py @@ -306,16 +306,21 @@ class FullStateDictConfig(StateDictConfig): >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): >>> state = fsdp.state_dict() - >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. + >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: - >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP + >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP >>> if dist.get_rank() == 0: - >>> # Load checkpoint only on rank 0 to avoid memory redundancy + >>> # Load checkpoint only on rank 0 to avoid memory redundancy >>> state_dict = torch.load("my_checkpoint.pt") >>> model.load_state_dict(state_dict) >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument >>> # communicates loaded checkpoint states from rank 0 to rest of the world. - >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) + >>> fsdp = FSDP( + ... model, + ... device_id=torch.cuda.current_device(), + ... auto_wrap_policy=..., + ... sync_module_states=True, + ... ) >>> # After this point, all ranks have FSDP model with loaded checkpoint. 
Attributes: diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index d1de71e5cc91..0eafd26e31f9 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -723,9 +723,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): if prev_state_dict_type is None: prev_state_dict_type = submodule._state_dict_type else: - assert ( - prev_state_dict_type == submodule._state_dict_type - ), "All FSDP modules should have the same state_dict_type." + assert prev_state_dict_type == submodule._state_dict_type, ( + "All FSDP modules should have the same state_dict_type." + ) if prev_state_dict_config is None: prev_state_dict_config = submodule._state_dict_config else: @@ -738,7 +738,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): assert isinstance( submodule._optim_state_dict_config, type(prev_optim_state_dict_config), - ), "All FSDP modules must have the same type of optim_state_dict_config." + ), ( + "All FSDP modules must have the same type of optim_state_dict_config." + ) submodule._state_dict_type = state_dict_type submodule._state_dict_config = state_dict_config @@ -2153,9 +2155,9 @@ def _get_param_to_fqn( """ param_to_param_names = _get_param_to_fqns(model) for param_names in param_to_param_names.values(): - assert ( - len(param_names) > 0 - ), "`_get_param_to_fqns()` should not construct empty lists" + assert len(param_names) > 0, ( + "`_get_param_to_fqns()` should not construct empty lists" + ) if len(param_names) > 1: raise RuntimeError( "Each parameter should only map to one parameter name but got " diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index d19cb720543c..b1611130c9e0 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -112,20 +112,16 @@ class ShardedGradScaler(GradScaler): self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) @overload - def scale(self, outputs: torch.Tensor) -> torch.Tensor: - ... + def scale(self, outputs: torch.Tensor) -> torch.Tensor: ... @overload - def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: - ... + def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ... @overload - def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: - ... + def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ... @overload - def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: - ... + def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ... def scale( self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]] @@ -323,8 +319,10 @@ class ShardedGradScaler(GradScaler): if isinstance(new_scale, float): self._scale.fill_(new_scale) # type: ignore[union-attr] else: - reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ + reason = ( + "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ torch.FloatTensor with requires_grad=False." 
+ ) assert new_scale.device.type == self._device, reason assert new_scale.numel() == 1, reason assert new_scale.requires_grad is False, reason diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index 55d6b3bc58ff..3bc8cb0ae380 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -61,9 +61,9 @@ def _post_order_apply( "Non-root modules should have their module name set but got " f"an empty module name for {module}" ) - assert isinstance( - optional_module, nn.Module - ), f"fn should return None or an nn.Module but got {optional_module}" + assert isinstance(optional_module, nn.Module), ( + f"fn should return None or an nn.Module but got {optional_module}" + ) setattr(parent_module, module_name, optional_module) _post_order_apply_inner(root_module, "", None) @@ -575,9 +575,9 @@ class _ConfigAutoWrap: ) _ConfigAutoWrap.in_autowrap_context = True # Get and save the wrapper cls for the context. - assert ( - "wrapper_cls" in kwargs.keys() - ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap." + assert "wrapper_cls" in kwargs.keys(), ( + "Expected to pass in wrapper_cls arg into _ConfigAutoWrap." + ) _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"]) del kwargs["wrapper_cls"] # Save the rest. diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index a9e35c36db7f..ad3307c13303 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -183,8 +183,7 @@ def parse_args(args): def launch(args): if args.no_python and not args.use_env: raise ValueError( - "When using the '--no-python' flag," - " you must also set the '--use-env' flag." + "When using the '--no-python' flag, you must also set the '--use-env' flag." ) run(args) diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 40fec71787ae..e08b9cad1b03 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -39,7 +39,10 @@ _REMOTE_MODULE_PICKLED_ATTRIBUTES = ( "module_rref", ) -_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES) # type: ignore[misc] +_SerializedRemoteModule = collections.namedtuple( # type: ignore[misc] + "_SerializedRemoteModule", + _REMOTE_MODULE_PICKLED_ATTRIBUTES, +) # These attributes are mostly from RemoteModule's parent class and are intentionally not pickled. # A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES diff --git a/torch/distributed/nn/jit/instantiator.py b/torch/distributed/nn/jit/instantiator.py index 22eb3bb3d4c1..9465eb036daa 100644 --- a/torch/distributed/nn/jit/instantiator.py +++ b/torch/distributed/nn/jit/instantiator.py @@ -26,15 +26,15 @@ sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH) def get_arg_return_types_from_interface(module_interface): - assert getattr( - module_interface, "__torch_script_interface__", False - ), "Expect a TorchScript class interface decorated by @torch.jit.interface." + assert getattr(module_interface, "__torch_script_interface__", False), ( + "Expect a TorchScript class interface decorated by @torch.jit.interface." 
+ ) qualified_name = torch._jit_internal._qualified_name(module_interface) cu = torch.jit._state._python_cu module_interface_c = cu.get_interface(qualified_name) - assert ( - "forward" in module_interface_c.getMethodNames() - ), f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}" + assert "forward" in module_interface_c.getMethodNames(), ( + f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}" + ) method_schema = module_interface_c.getMethod("forward") arg_str_list = [] diff --git a/torch/distributed/optim/__init__.py b/torch/distributed/optim/__init__.py index d8fee468504d..faac68bb6329 100644 --- a/torch/distributed/optim/__init__.py +++ b/torch/distributed/optim/__init__.py @@ -5,6 +5,7 @@ optimizer locally on the workers where the parameters live. The distributed optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to apply the gradients on each worker. """ + import warnings import torch diff --git a/torch/distributed/optim/apply_optimizer_in_backward.py b/torch/distributed/optim/apply_optimizer_in_backward.py index 741fe350121b..1ff9854793df 100644 --- a/torch/distributed/optim/apply_optimizer_in_backward.py +++ b/torch/distributed/optim/apply_optimizer_in_backward.py @@ -44,10 +44,10 @@ def _apply_optimizer_in_backward( param_1 = next(params_generator) remainder_params = list(params_generator) - apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02}) - apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04}) + apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02}) + apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04}) - model(...).sum().backward() # after backward, parameters will already + model(...).sum().backward() # after backward, parameters will already # have their registered optimizer(s) applied. """ @@ -111,7 +111,7 @@ def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Opt List[torch.optim.Optimizer]: the in-backward optimizers. Example:: - _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01}) + _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01}) optims = _get_optimizers_in_backward(model) """ optims: list[torch.optim.Optimizer] = [] diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index dbbd2ac97131..c8be46e6d155 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -147,12 +147,10 @@ class _NamedOptimizer(optim.Optimizer): return self._post_state_dict({"state": ret_state, "param_groups": ret_groups}) @overload - def step(self, closure: None = ...) -> None: - ... + def step(self, closure: None = ...) -> None: ... @overload - def step(self, closure: Callable[[], float]) -> float: - ... + def step(self, closure: Callable[[], float]) -> float: ... def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: """ diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index db3cd24297b1..10ec6ae2eb1d 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
r"""Zero Redundancy Optimizer.""" + import collections import copy import enum @@ -262,9 +263,9 @@ class _OverlapInfo: meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles`` in preparation for the next iteration. """ - assert ( - len(self.broadcast_handles) == self.num_bucket_assignments - ), f"Missing at least one broadcast handle on rank {dist.get_rank()}" + assert len(self.broadcast_handles) == self.num_bucket_assignments, ( + f"Missing at least one broadcast handle on rank {dist.get_rank()}" + ) _ = [x.wait() for x in self.broadcast_handles] self.broadcast_handles.clear() @@ -909,9 +910,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): params_per_rank = overlap_info.params_per_rank offsets = overlap_info.offsets - self._bucket_assignments_per_rank_cache[assigned_rank][ - bucket_index - ] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset) + self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = ( + _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset) + ) if self.global_rank == assigned_rank: offsets[bucket_index] = len(params_per_rank[assigned_rank]) params_per_rank[assigned_rank].extend(bucket_params) @@ -927,9 +928,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): mapping bucket indices to :class:`_DDPBucketAssignment` s for each rank. """ - assert ( - self._overlap_with_ddp - ), "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" + assert self._overlap_with_ddp, ( + "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" + ) if len(self._bucket_assignments_per_rank_cache) > 0: return self._bucket_assignments_per_rank_cache @@ -1076,9 +1077,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): "Specifying `gradients` should not " "be used when `overlap_with_ddp=False`" ) - assert ( - closure is None - ), "`closure` is not supported when using a local functional optimizer" + assert closure is None, ( + "`closure` is not supported when using a local functional optimizer" + ) loss = self.optim.step(gradients=gradients) # Sync any updated attributes in the local optimizer to the exposed @@ -1221,9 +1222,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): for rank, local_state_dict in enumerate(self._all_state_dicts): local_param_groups = local_state_dict["param_groups"] global_param_groups = self._partition_parameters()[rank] - assert len(local_param_groups) == len( - global_param_groups - ), "Mismatch between number of local and global parameter groups" + assert len(local_param_groups) == len(global_param_groups), ( + "Mismatch between number of local and global parameter groups" + ) for local_param_group, global_param_group in zip( local_param_groups, global_param_groups @@ -1233,9 +1234,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): local_param_indices = local_param_group["params"] global_params = global_param_group["params"] - assert len(local_param_indices) == len( - global_params - ), "Mismatch between number of local and global parameters in parameter group" + assert len(local_param_indices) == len(global_params), ( + "Mismatch between number of local and global parameters in parameter group" + ) for local_param_index, global_param in zip( local_param_indices, global_params ): @@ -1268,9 +1269,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): dst_param_groups (list[dict]): parameter groups giving the attribute settings to set. 
""" - assert len(src_param_groups) == len( - dst_param_groups - ), "Mismatch between number of source and destination parameter groups" + assert len(src_param_groups) == len(dst_param_groups), ( + "Mismatch between number of source and destination parameter groups" + ) for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups): # Sync all attributes except the parameters for attr in filter(lambda x: x != "params", src_param_group.keys()): @@ -1479,9 +1480,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): The local optimizer is saved in ``self.optim``. """ - assert ( - self._optim_constructor is not None - ), "The local optimizer class has not been set" + assert self._optim_constructor is not None, ( + "The local optimizer class has not been set" + ) param_groups = self._partition_parameters()[self.rank] # `overlap_with_ddp=True` requires a local functional optimizer @@ -1508,7 +1509,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): "error due to an empty parameter list", self._optim_constructor, ) - self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef] + self.optim: Any = self._optim_constructor( + params, **self._optim_defaults + ) # type: ignore[no-redef] # Log information about the DDP and ZeRO bucketing if dist.get_debug_level() != dist.DebugLevel.OFF: @@ -1531,7 +1534,9 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): else: # NOTE: Passing `param_groups` into the local optimizer constructor # bypasses the empty parameter list check - self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults) # type: ignore[no-redef] + self.optim: Optimizer = self._optim_constructor( + param_groups, **self._optim_defaults + ) # type: ignore[no-redef] # TODO: Manually add `self.param_groups` if using a functional # optimizer; remove this if/when the functional optimizers support diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 5bffdccf9342..416965e80ba3 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -123,12 +123,11 @@ def _insert_stage_symbolic_backward( # getitem calls. If we have a target other than getitem in this # (forward-only) code, there is a bug. assert node.target == operator.getitem, ( - "Found non-getitem call in forward pass. " - "Please report a bug to PiPPy" + "Found non-getitem call in forward pass. Please report a bug to PiPPy" + ) + assert len(node.args) == 2, ( + "Found malformed getitem call. Please report a bug to PiPPy" ) - assert ( - len(node.args) == 2 - ), "Found malformed getitem call. Please report a bug to PiPPy" indexed_value, node_idx = tuple(node.args) # indexed_value is a collection that we are indexing into. It could @@ -249,8 +248,8 @@ class LossWrapper(torch.nn.Module): targets value into the loss function, and get and return the loss value, which will be backpropagated by PiPPy. The above class would then be instantiated like:: - model = ... # instantiate the model - loss_fn = torch.nn.MSELoss() # for the sake of demonstration + model = ... # instantiate the model + loss_fn = torch.nn.MSELoss() # for the sake of demonstration wrapper = MyModelWrapper(model, loss_fn) pipe = Pipe.from_tracing(wrapper, ...) 
@@ -818,9 +817,9 @@ class Pipe(torch.nn.Module): # Get submodule callee = root.get_submodule(callee_name) - assert not hasattr( - callee, param_fqn - ), f"Module {callee_name} already has a parameter named {param_fqn}" + assert not hasattr(callee, param_fqn), ( + f"Module {callee_name} already has a parameter named {param_fqn}" + ) # Assign the parameter to the submodule if is_buffer: @@ -979,7 +978,7 @@ class Pipe(torch.nn.Module): else: logger.debug("Pipeline is in inference mode, backward pass not generated") - logger.debug("Full pipe model:\n" f"{split}") # noqa: G004 + logger.debug(f"Full pipe model:\n{split}") # noqa: G004 return Pipe( split, @@ -1184,7 +1183,7 @@ def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]): except AttributeError as e: raise AttributeError( f"Specified target {qualname} referenced " - f'nonexistent module {".".join(atoms[: i + 1])}' + f"nonexistent module {'.'.join(atoms[: i + 1])}" ) from e mod_to_wrap = getattr(predecessor_module, atoms[-1]) diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index a31ee53206ab..4269375a1c6f 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -306,17 +306,17 @@ def stage_backward( if isinstance(output_val, torch.Tensor): if not output_val.requires_grad and output_val.grad_fn is None: return - assert isinstance( - grad_val, (torch.Tensor, type(None)) - ), f"Expected Tensor or None gradient but got {type(grad_val)}" + assert isinstance(grad_val, (torch.Tensor, type(None))), ( + f"Expected Tensor or None gradient but got {type(grad_val)}" + ) stage_output_tensors.append(output_val) output_grad_tensors.append(grad_val) elif isinstance(output_val, (tuple, list)): if grad_val is None: return - assert isinstance( - grad_val, (tuple, list) - ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + assert isinstance(grad_val, (tuple, list)), ( + f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + ) assert len(output_val) == len(grad_val) for ov, gv in zip(output_val, grad_val): extract_tensors_with_grads( @@ -350,7 +350,8 @@ def stage_backward( ) torch.autograd.backward( - stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type] + stage_output_tensors, + grad_tensors=output_grad_tensors, # type: ignore[arg-type] ) # Extract gradients wrt the input values diff --git a/torch/distributed/pipelining/microbatch.py b/torch/distributed/pipelining/microbatch.py index a495c639101c..28d5daf8d236 100644 --- a/torch/distributed/pipelining/microbatch.py +++ b/torch/distributed/pipelining/microbatch.py @@ -140,9 +140,9 @@ def _shard_dict_of_args( real_num_chunks = num_chunks first_tensor = True - assert len(args_dict) == len( - args_chunk_spec - ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" + assert len(args_dict) == len(args_chunk_spec), ( + f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" + ) for arg_key, arg in args_dict.items(): flat, spec = tree_flatten(arg) diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index 1462d6ad8420..e431e29b77e6 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -706,7 +706,9 @@ class Schedule1F1B(PipelineScheduleSingle): recv_work.wait() # Compute - output = 
self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + output = self._stage.forward_one_chunk( + fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index] + ) # type: ignore[index] # Clear previous chunk's forward sends (hopefully they have well # finished, otherwise, we are heavily communication bound, in which @@ -762,7 +764,9 @@ class Schedule1F1B(PipelineScheduleSingle): fuse_work.wait() # Now do the fwd - output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index] + output = self._stage.forward_one_chunk( + fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index] + ) # type: ignore[index] # Compute loss self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) @@ -992,9 +996,9 @@ def _add_send_recv( progress = False # go in order of ranks even if dict keys aren't ordered for rank in sorted(compute_actions): - assert ( - len(compute_actions[rank]) > 0 - ), f"{rank=}, {len(compute_actions[rank])=}" + assert len(compute_actions[rank]) > 0, ( + f"{rank=}, {len(compute_actions[rank])=}" + ) action = compute_actions[rank][0] if not _ready_to_schedule(action, prev_actions[rank]): @@ -1026,9 +1030,9 @@ def _validate_schedule( num_stages: int, num_microbatches: int, ) -> dict[int, int]: - assert ( - len(actions) == pp_group_size - ), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}" + assert len(actions) == pp_group_size, ( + f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}" + ) for rank in range(pp_group_size): assert rank in actions, f"Schedule is missing actions for rank {rank}" @@ -1048,36 +1052,36 @@ def _validate_schedule( for action in actions[rank]: if action is None: continue - assert isinstance( - action, _Action - ), f"Got an invalid action: {action}, expected instance of _Action" + assert isinstance(action, _Action), ( + f"Got an invalid action: {action}, expected instance of _Action" + ) s_id = action.stage_index ctype = action.computation_type mb_id = action.microbatch_index if ctype == F: stage_actions[s_id][F].add(mb_id) elif ctype == B: - assert ( - mb_id in stage_actions[s_id][F] - ), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward" + assert mb_id in stage_actions[s_id][F], ( + f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward" + ) stage_actions[s_id][B].add(mb_id) elif ctype == I: - assert ( - mb_id in stage_actions[s_id][F] - ), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward" + assert mb_id in stage_actions[s_id][F], ( + f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward" + ) stage_actions[s_id][I].add(mb_id) elif ctype == W: - assert ( - mb_id in stage_actions[s_id][I] - ), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input" + assert mb_id in stage_actions[s_id][I], ( + f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input" + ) stage_actions[s_id][W].add(mb_id) if s_id not in stage_index_to_rank_mapping: stage_index_to_rank_mapping[s_id] = rank else: existing_rank = stage_index_to_rank_mapping[s_id] - assert ( - rank == existing_rank - ), f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}" + assert rank == existing_rank, ( + f"Stage {s_id} is assigned to both rank 
{rank} and rank {existing_rank}" + ) for s_id in stage_actions: f_mb = len(stage_actions[s_id][F]) @@ -1085,14 +1089,14 @@ def _validate_schedule( i_mb = len(stage_actions[s_id][I]) w_mb = len(stage_actions[s_id][W]) - assert ( - f_mb == num_microbatches - ), f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}" + assert f_mb == num_microbatches, ( + f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}" + ) - assert ( - b_mb + (i_mb + w_mb) // 2 == num_microbatches - ), f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \ + assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, ( + f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \ but got B={b_mb}, I={i_mb}, W={w_mb}" + ) return stage_index_to_rank_mapping @@ -1289,9 +1293,9 @@ class PipelineScheduleMulti(_PipelineSchedule): computation_type = action.computation_type mb_index = action.microbatch_index stage_index = action.stage_index - assert ( - mb_index is not None - ), "All currently supported action types require valid microbatch_index" + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) if computation_type == _ComputationType.FORWARD: # perform forward computation stage = stage_index_to_stage[stage_index] @@ -1362,9 +1366,9 @@ class PipelineScheduleMulti(_PipelineSchedule): computation_type = prev_rank_action.computation_type mb_index = prev_rank_action.microbatch_index stage_index = prev_rank_action.stage_index - assert ( - mb_index is not None - ), "All currently supported action types require valid microbatch_index" + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) # Only handle sends for the forward from a previous rank if computation_type == _ComputationType.FORWARD: # If not the last stage, then receive fwd activations @@ -1393,9 +1397,9 @@ class PipelineScheduleMulti(_PipelineSchedule): computation_type = next_rank_action.computation_type mb_index = next_rank_action.microbatch_index stage_index = next_rank_action.stage_index - assert ( - mb_index is not None - ), "All currently supported action types require valid microbatch_index" + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) # Only handle receives for the backwards from a next rank if computation_type in (FORWARD, BACKWARD_WEIGHT): # Next rank doing forward or weight update has no influence for the current rank backward recv @@ -1503,9 +1507,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): """Dump a CSV representation of the compute + comms schedule into a file with the provided filename.""" # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible # that it does not exist if it was created from a compute_comms schedule. 
- assert ( - self.pipeline_order_with_comms is not None - ), "Must initialize compute_comms schedule before dump_csv" + assert self.pipeline_order_with_comms is not None, ( + "Must initialize compute_comms schedule before dump_csv" + ) with open(filename, "w", newline="") as csvfile: writer = csv.writer(csvfile) for rank in self.pipeline_order_with_comms: @@ -1541,9 +1545,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): stage.stage_index: stage for stage in self._stages } - assert ( - self.pipeline_order_with_comms is not None - ), "Must call _load_actions() before calling _step_microbatches()" + assert self.pipeline_order_with_comms is not None, ( + "Must call _load_actions() before calling _step_microbatches()" + ) # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use bwd_recv_ops: dict[tuple[int, int], Work] = {} @@ -1562,9 +1566,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): unshard_ops[stage_idx].wait() del unshard_ops[stage_idx] unsharded_stages.add(stage_idx) - assert ( - stage_idx in unsharded_stages - ), f"Attempted to compute on sharded {stage_idx=}" + assert stage_idx in unsharded_stages, ( + f"Attempted to compute on sharded {stage_idx=}" + ) # count either full_backward or backward_weight together, to determine when to sync DP grads backward_counter: Counter[int] = Counter() @@ -1606,7 +1610,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): assert ( stage_idx, mb_index, - ) not in fwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing forward" + ) not in fwd_recv_ops, ( + "Recv twice for {stage_idx=} {mb_index=} without executing forward" + ) fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( stage.get_fwd_recv_ops(mb_index) ) @@ -1614,7 +1620,9 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): assert ( stage_idx, mb_index, - ) not in bwd_recv_ops, "Recv twice for {stage_idx=} {mb_index=} without executing backward" + ) not in bwd_recv_ops, ( + "Recv twice for {stage_idx=} {mb_index=} without executing backward" + ) bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( stage.get_bwd_recv_ops(mb_index) ) @@ -1627,12 +1635,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator] elif comp_type == RESHARD: if stage_uses_fsdp: - assert ( - stage_idx in unsharded_stages - ), f"Resharding {stage_idx=} without unsharding" - assert ( - stage_idx not in unshard_ops - ), f"Resharding {stage_idx=} before finishing unshard" + assert stage_idx in unsharded_stages, ( + f"Resharding {stage_idx=} without unsharding" + ) + assert stage_idx not in unshard_ops, ( + f"Resharding {stage_idx=} before finishing unshard" + ) stage.submod.reshard() # type: ignore[operator] elif comp_type == FORWARD: if stage_uses_fsdp: @@ -1739,7 +1747,12 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): ) # TODO(whc) what is the best practice for printing a multiline log? 
# logger will split it into multiple log lines, but this makes it hard to read (too wide) - print(_format_pipeline_order(self.pipeline_order_with_comms, error_step_number=time_step)) # type: ignore[arg-type] + print( + _format_pipeline_order( + self.pipeline_order_with_comms, # type: ignore[arg-type] + error_step_number=time_step, + ) + ) raise e # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index d966589cbab7..71260fcae517 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -243,16 +243,16 @@ class _PipelineStageBase(ABC): configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches which could show up as hangs, silent corruption, or other errors. """ - assert ( - self._outputs_meta is None - ), "Attempting to reconfigure output_meta, which is not supported" + assert self._outputs_meta is None, ( + "Attempting to reconfigure output_meta, which is not supported" + ) self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] def get_outputs_meta(self) -> tuple[torch.Tensor, ...]: """Get the output metadata (meta tensors) reprensenting the outputs of this stage""" - assert ( - self._outputs_meta is not None - ), "Attempted to get_outputs_meta() without configuring output meta" + assert self._outputs_meta is not None, ( + "Attempted to get_outputs_meta() without configuring output meta" + ) return self._outputs_meta def _create_grad_send_info( @@ -358,12 +358,12 @@ class _PipelineStageBase(ABC): prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs) for info, tensor in zip(recv_infos, prev_stage_outputs): - assert isinstance( - tensor, torch.Tensor - ), f"expected tensor values as outputs from prev stage, got {type(tensor)}" - assert isinstance( - info, _RecvInfo - ), "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + assert isinstance(tensor, torch.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + ) # We don't need to do a data copy here, since we can directly pass the activation tensor reference from # one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve @@ -376,9 +376,9 @@ class _PipelineStageBase(ABC): """ Returns the input grad tensors for this stage, which correspond to the stage inputs during forward. """ - assert ( - self.has_backward - ), "can't steal_bwd_input if this stage doesn't have backward" + assert self.has_backward, ( + "can't steal_bwd_input if this stage doesn't have backward" + ) assert not self.is_first, "can't get bwd output if this stage is first" self._check_chunk_id(mb_index) @@ -391,22 +391,22 @@ class _PipelineStageBase(ABC): Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. Does not detach or set '_requires_grad'. 
""" - assert isinstance( - next_stage_bwd_outputs, tuple - ), f"Expected tuple, got {type(next_stage_bwd_outputs)}" + assert isinstance(next_stage_bwd_outputs, tuple), ( + f"Expected tuple, got {type(next_stage_bwd_outputs)}" + ) - assert ( - self.has_backward - ), "can't set bwd input if this stage doesn't have backward" + assert self.has_backward, ( + "can't set bwd input if this stage doesn't have backward" + ) assert not self.is_last, "can't set bwd input if this stage is last" recv_infos = self.grad_recv_info[mb_index] for info, tensor in zip(recv_infos, next_stage_bwd_outputs): - assert isinstance( - tensor, torch.Tensor - ), f"expected tensor values as outputs from prev stage, got {type(tensor)}" - assert isinstance( - info, _RecvInfo - ), f"Expected a recv info, got {type(info)}" + assert isinstance(tensor, torch.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + f"Expected a recv info, got {type(info)}" + ) info.buffer = tensor def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: @@ -1053,9 +1053,9 @@ class _PipelineStage(_PipelineStageBase): # If the input is a getitem, we need to go deeper arg_node = arg_node.args[0] - assert ( - arg_node.op == "call_module" - ), f"Expecting call_module, got {arg_node.op}" + assert arg_node.op == "call_module", ( + f"Expecting call_module, got {arg_node.op}" + ) src_stage = self.get_stage_index_of_submod(arg_node.name) # Create a receive buffer for this placeholder @@ -1081,7 +1081,8 @@ class _PipelineStage(_PipelineStageBase): args_recv_info: list[InputInfo] = [] # Filter out placeholder nodes from `self.submod` (a GraphModule) placeholders = filter( # type: ignore[var-annotated] - lambda node: node.op == "placeholder", self.submod.graph.nodes # type: ignore[arg-type, union-attr] + lambda node: node.op == "placeholder", # type: ignore[arg-type] + self.submod.graph.nodes, # type: ignore[arg-type,union-attr] ) # `placeholders` are nodes internal to submod. # `self.node.args` are dependency nodes in the outer graph. @@ -1300,9 +1301,9 @@ class PipelineStage(_PipelineStageBase): raise RuntimeError( "Failed to perform pipeline shape inference- are your inputs on the same device as your module?" 
) from e - assert ( - output_args is not None - ), "If passing input_args, also pass output_args to override shape inference" + assert output_args is not None, ( + "If passing input_args, also pass output_args to override shape inference" + ) self._configure_outputs_meta( (output_args,) if isinstance(output_args, torch.Tensor) else output_args ) @@ -1346,9 +1347,9 @@ class PipelineStage(_PipelineStageBase): ) args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args) else: - assert ( - len(args) == 0 - ), "Can't supply input args for shape inference on non-first stage" + assert len(args) == 0, ( + "Can't supply input args for shape inference on non-first stage" + ) objects = [None] logger.debug( "Shape inference: stage %s receiving from stage %s", diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index d5169c58161a..356b30cb7f0b 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -80,9 +80,9 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa world_size = world_size_opt if rank != -1 or world_size != -1 or world_size_opt is None: query_dict = _query_to_dict(result.query) - assert ( - "rank" not in query_dict and "world_size" not in query_dict - ), f"The url: {url} has node-specific arguments(rank, world_size) already." + assert "rank" not in query_dict and "world_size" not in query_dict, ( + f"The url: {url} has node-specific arguments(rank, world_size) already." + ) if rank != -1: query_dict["rank"] = str(rank) if world_size != -1 or world_size_opt is None: diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index 164ba4056eed..d4a6712e0d66 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -137,13 +137,13 @@ def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None): with _all_gather_dict_lock: if not worker_names: worker_names = _ALL_WORKER_NAMES - assert ( - worker_name in worker_names - ), f"{worker_name} is not expected by leader." + assert worker_name in worker_names, ( + f"{worker_name} is not expected by leader." + ) states = _all_gather_sequence_id_to_states[sequence_id] - assert ( - worker_name not in states.gathered_objects - ), f"{worker_name} reported intent sequence id {sequence_id} twice. " + assert worker_name not in states.gathered_objects, ( + f"{worker_name} reported intent sequence id {sequence_id} twice. " + ) states.gathered_objects[worker_name] = obj if worker_names == set(states.gathered_objects.keys()): states.proceed_signal.set() @@ -153,9 +153,9 @@ def _broadcast_to_followers(sequence_id, objects_map): with _all_gather_dict_lock: states = _all_gather_sequence_id_to_states[sequence_id] - assert ( - not states.proceed_signal.is_set() - ), f"Termination signal sequence id {sequence_id} got set twice." + assert not states.proceed_signal.is_set(), ( + f"Termination signal sequence id {sequence_id} got set twice." + ) states.gathered_objects = objects_map states.proceed_signal.set() @@ -202,9 +202,9 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): function blocks until all workers have received the gathered results. """ if not worker_names: - assert ( - _ALL_WORKER_NAMES is not None - ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`." + assert _ALL_WORKER_NAMES is not None, ( + "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`." 
+ ) worker_names = _ALL_WORKER_NAMES leader_name = min(worker_names) @@ -930,8 +930,7 @@ def _get_should_profile(): ActiveProfilerType = torch._C._profiler.ActiveProfilerType return ( torch.autograd._profiler_enabled() - and torch._C._autograd._profiler_type() - == ActiveProfilerType.LEGACY # type: ignore[attr-defined] + and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined] ) diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 2be42a38ee25..9f1b13f948d0 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -23,7 +23,7 @@ def _to_device(device: DeviceType) -> torch.device: def _to_device_map( - device_map: dict[DeviceType, DeviceType] + device_map: dict[DeviceType, DeviceType], ) -> dict[torch.device, torch.device]: full_device_map: dict[torch.device, torch.device] = {} reverse_map: dict[torch.device, torch.device] = {} @@ -127,7 +127,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): >>> options = TensorPipeRpcBackendOptions( >>> num_worker_threads=8, >>> device_maps={"worker1": {0: 1}} - >>> # maps worker0's cuda:0 to worker1's cuda:1 + >>> # maps worker0's cuda:0 to worker1's cuda:1 >>> ) >>> options.set_device_map("worker1", {1: 2}) >>> # maps worker0's cuda:1 to worker1's cuda:2 diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index 7b9a4d0bcde9..b0cb1713bcc9 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -63,10 +63,14 @@ class _server_process_global_profile(profile): >>> import torch.distributed.rpc as rpc >>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> x, y = torch.tensor(1), torch.tensor(2) - >>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile) + >>> outer_profile_rref = rpc.remote( + ... dst_worker_name, rpc._server_process_global_profile + ... ) >>> outer_profile_rref.rpc_sync().__enter__() >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y)) - >>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile) + >>> inner_profile_rref = rpc.remote( + ... dst_worker_name, rpc._server_process_global_profile + ... ) >>> inner_profile_rref.rpc_sync().__enter__() >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y)) >>> inner_profile_rref.rpc_sync().__exit__(None, None, None) diff --git a/torch/distributed/run.py b/torch/distributed/run.py index e050f1e6b586..b1c073dc861f 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -289,9 +289,9 @@ Important Notices :: - >>> # xdoctest: +SKIP("stub") - >>> import torch.distributed as dist - >>> dist.init_process_group(backend="gloo|nccl") + >>> # xdoctest: +SKIP("stub") + >>> import torch.distributed as dist + >>> dist.init_process_group(backend="gloo|nccl") 3. In your training program, you can either use regular distributed functions or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your @@ -302,9 +302,9 @@ Important Notices :: local_rank = int(os.environ["LOCAL_RANK"]) - model = torch.nn.parallel.DistributedDataParallel(model, - device_ids=[local_rank], - output_device=local_rank) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[local_rank], output_device=local_rank + ) Please ensure that ``device_ids`` argument is set to be the only GPU device id that your code will be operating on. 
This is generally the local rank of the @@ -331,17 +331,18 @@ utility :: - def main(): - load_checkpoint(checkpoint_path) - initialize() - train() + def main(): + load_checkpoint(checkpoint_path) + initialize() + train() - def train(): - for batch in iter(dataset): - train_step(batch) - if should_checkpoint: - save_checkpoint(checkpoint_path) + def train(): + for batch in iter(dataset): + train_step(batch) + + if should_checkpoint: + save_checkpoint(checkpoint_path) 9. (Recommended) On worker errors, this tool will summarize the details of the error (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp) @@ -353,17 +354,19 @@ utility :: - from torch.distributed.elastic.multiprocessing.errors import record + from torch.distributed.elastic.multiprocessing.errors import record - @record - def main(): - # do train - pass - if __name__ == "__main__": - main() + @record + def main(): + # do train + pass + + if __name__ == "__main__": + main() """ # noqa: E501 + import os import sys import uuid diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 41e3ae014dad..ec8e268f51fc 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -297,9 +297,9 @@ class DTensor(torch.Tensor): @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - assert ( - flatten_spec is not None - ), "Expecting spec to be not None from `__tensor_flatten__` return value!" + assert flatten_spec is not None, ( + "Expecting spec to be not None from `__tensor_flatten__` return value!" + ) local_tensor = inner_tensors["_local_tensor"] spec, requires_grad = flatten_spec unflatten_tensor_meta = TensorMeta( @@ -694,9 +694,7 @@ def distribute_tensor( xla_distribute_tensor, ) - return xla_distribute_tensor( - tensor, device_mesh, placements - ) # type:ignore[return-value] + return xla_distribute_tensor(tensor, device_mesh, placements) # type:ignore[return-value] except ImportError as e: msg = "To use DTensor API with xla, you must install the torch_xla package!" raise ImportError(msg) from e @@ -930,7 +928,9 @@ def distribute_module( FutureWarning, stacklevel=2, ) - module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] + module.register_forward_pre_hook( + lambda _, inputs: input_fn(inputs, device_mesh) # type: ignore[call-arg] + ) elif num_args == 3: # input_fn takes in module, inputs, device mesh module.register_forward_pre_hook( @@ -990,9 +990,9 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) # check device_mesh against placements - assert device_mesh.ndim == len( - placements - ), "mesh dimension does not match the length of placements" + assert device_mesh.ndim == len(placements), ( + "mesh dimension does not match the length of placements" + ) assert kwargs["layout"] == torch.strided, "layout value not supported!"
torch_stride = torch._prims_common.make_contiguous_strides_for(size) diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 675625cf9cfd..61d2dbd02055 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -75,7 +75,8 @@ def found_inf_reduce_handler( ) -> None: op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) local_tensor_args = pytree.tree_unflatten( - cast(list[object], op_info.local_args), op_info.args_tree_spec # type: ignore[arg-type] + cast(list[object], op_info.local_args), + op_info.args_tree_spec, # type: ignore[arg-type] ) local_tensor_args = cast(tuple[object, ...], local_tensor_args) op_call(*local_tensor_args, **op_info.local_kwargs) @@ -200,8 +201,9 @@ class OpDispatcher: # did not already construct one random._rng_tracker = random.OffsetBasedRNGTracker(mesh) - first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast( - torch.Tensor, local_tensor_args[0] + first_arg, first_local_arg = ( + cast(dtensor.DTensor, args[0]), + cast(torch.Tensor, local_tensor_args[0]), ) rng_context = ( random._rng_tracker._distribute_region(first_arg._spec) @@ -422,18 +424,18 @@ class OpDispatcher: def wrap(res: object, spec: OutputSpecType) -> object: if isinstance(res, torch.Tensor): if spec is not None: - assert isinstance( - spec, DTensorSpec - ), f"output spec does not match with output! Expected DTensorSpec, got {spec}." + assert isinstance(spec, DTensorSpec), ( + f"output spec does not match with output! Expected DTensorSpec, got {spec}." + ) return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) else: # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor assert res.ndim == 0, "output tensor should be scalar!" return res elif isinstance(res, (list, tuple)): - assert spec is not None and isinstance( - spec, (list, tuple) - ), f"output spec does not match with output! Expected list/tuple, got {spec}." + assert spec is not None and isinstance(spec, (list, tuple)), ( + f"output spec does not match with output! Expected list/tuple, got {spec}." + ) res_list = [] for e, s in zip(res, spec): res_list.append(OpDispatcher.wrap(e, s)) diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index feef87aae1eb..995d4ad73da7 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -152,9 +152,9 @@ class OpStrategy(StrategyType): if isinstance(output_spec, DTensorSpec): return output_spec.mesh.shape else: - assert isinstance( - output_spec, tuple - ), "found no DTensorSpec in the OpStrategy!" + assert isinstance(output_spec, tuple), ( + "found no DTensorSpec in the OpStrategy!" + ) assert output_spec[0] is not None return output_spec[0].mesh.shape diff --git a/torch/distributed/tensor/_ops/_einsum_strategy.py b/torch/distributed/tensor/_ops/_einsum_strategy.py index 0db79ed2f700..5953721d219c 100644 --- a/torch/distributed/tensor/_ops/_einsum_strategy.py +++ b/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -63,9 +63,9 @@ class EinsumDims: if is_batch_dim: batch_dims.append(dim_char) else: - assert ( - len(input_dims) == 2 - ), "free dimension only supported for two inputs!" + assert len(input_dims) == 2, ( + "free dimension only supported for two inputs!" 
+ ) lhs, rhs = input_dims if dim_char in lhs: lhs_out_only_dims.append(dim_char) diff --git a/torch/distributed/tensor/_ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py index 70295d7ad90e..e2e29d77800b 100644 --- a/torch/distributed/tensor/_ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -89,9 +89,9 @@ class _MaskPartial(Partial): # override parent logic to perform partial mask for embedding num_chunks = mesh.size(mesh_dim) # get local shard size and offset on the embedding_dim - assert ( - self.offset_shape is not None - ), "offset_shape needs to be set for _MaskPartial" + assert self.offset_shape is not None, ( + "offset_shape needs to be set for _MaskPartial" + ) local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim( self.offset_shape[self.offset_dim], num_chunks, diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 597279c63fc5..04d0374f7d65 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -994,9 +994,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy ) output_specs_list.append(weight_out_spec if output_mask[1] else None) else: - assert ( - output_mask[1] is False - ), "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." + assert output_mask[1] is False, ( + "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." + ) output_specs_list.append(None) # arg: bias @@ -1020,9 +1020,9 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy ) output_specs_list.append(bias_out_spec if output_mask[2] else None) else: - assert ( - output_mask[2] is False - ), "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." + assert output_mask[2] is False, ( + "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." + ) output_specs_list.append(None) out_tuple_strategy.strategies.append( diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index b2127b05a38c..fc0448e3ebb7 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -155,9 +155,9 @@ def _scaled_mm_like_strategy( assert isinstance(scale_mat2_strategy, OpStrategy) # TODO: add support for these later assert bias_strategy is None, "_scaled_mm on DTensors doesn't support bias" - assert ( - scale_result_strategy is None - ), "_scaled_mm on DTensors doesn't support scale_result" + assert scale_result_strategy is None, ( + "_scaled_mm on DTensors doesn't support scale_result" + ) # generate all possible strategies for mm mm_strategy = gen_einsum_strategies(mm_equation, mesh) # filter out invalid strategies and associate costs diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index db238cadbff9..a26ab5ecc6c5 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -445,9 +445,9 @@ def pointwise_strategy( followed_strategy = op_schema.args_schema[max_shards_strategy_index] - assert isinstance( - followed_strategy, OpStrategy - ), f"no strategy to follow for {op_schema}!" + assert isinstance(followed_strategy, OpStrategy), ( + f"no strategy to follow for {op_schema}!" 
+ ) return common_pointwise_strategy( mesh, op_schema.args_schema, followed_strategy, linearity ) diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index 3191208203b9..0326dbe8fd34 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -254,9 +254,9 @@ def dim_movedim( def dim_repeat(ndim: int, sizes: Shape) -> DimMap: sizes = normalize_sizes(sizes) - assert ( - len(sizes) >= ndim - ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." + assert len(sizes) >= ndim, ( + f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." + ) pad = len(sizes) - ndim return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) @@ -275,9 +275,9 @@ def infer_size(total_size: int, sizes: Shape) -> Shape: if infers: size = -size missing_size = total_size // size - assert ( - total_size % size == 0 - ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements." + assert total_size % size == 0, ( + f"size inferred for -1 is not integral {sizes} should have {total_size} elements." + ) return tuple(s if s != -1 else missing_size for s in sizes) assert size == total_size, f"sizes do not match {total_size} vs {size}" return sizes @@ -538,9 +538,9 @@ def propagate_shape_and_sharding( for size, shard in zip(mesh_sizes, input_src_placements): if isinstance(shard, Shard) and shard.dim == in_dim: submesh_size *= size - assert ( - out_size % submesh_size == 0 - ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." + assert out_size % submesh_size == 0, ( + f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." + ) # we will only shard our first component of the split return in_dim if cmd.split_id == 0 else None diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index 8cd39ba7f943..7f1894ed73ff 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -45,7 +45,7 @@ def register_prop_rule( # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def wrapper( - impl: Callable[[OpSchema], OutputSharding] + impl: Callable[[OpSchema], OutputSharding], ) -> Callable[[OpSchema], OutputSharding]: overloads = op if isinstance(op, list) else [op] for overload in overloads: @@ -102,7 +102,7 @@ def register_op_strategy( def as_list( - x: Union[list[object], object] + x: Union[list[object], object], # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type. 
) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type] # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args, diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index f1ab11d9ac5b..0d80225e7c2b 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -231,9 +231,9 @@ def redistribute_local_tensor( local_tensor, device_mesh, i, my_coordinate[i] ) else: - assert ( - current.is_shard() - ), f"Current placement should be shard but found {current}" + assert current.is_shard(), ( + f"Current placement should be shard but found {current}" + ) shard_spec = cast(Shard, current) if shard_spec.dim != target_placement.dim: new_local_tensor = shard_spec._to_new_shard_dim( diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index e81957506d63..29a836516144 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -487,9 +487,9 @@ class ShardingPropagator: strategy_costs: list[float] = [] for strtg in strategy.strategies: - assert ( - strtg.redistribute_cost is not None - ), "must set redistribute cost each strategy!" + assert strtg.redistribute_cost is not None, ( + "must set redistribute cost each strategy!" + ) redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost)) strategy_costs.append(redistribute_cost) diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index d4ea5ce844bc..61705610f08f 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -73,9 +73,9 @@ def compute_local_shape_and_global_offset( if isinstance(placement, Shard): shard_dim = placement.dim local_offset = [0] * len(global_shape) - assert shard_dim < len( - local_shape - ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + ) shard_size, shard_offset = placement._local_shard_size_on_dim( local_shape[shard_dim], mesh_dim_size, @@ -141,16 +141,15 @@ def compute_local_shape_and_global_offset( if isinstance(placement, _StridedShard): strided_part_seen[shard_dim] = True - shard_idx_stride_by_mesh_dim[shard_dim][ - idx - ] = num_shards_by_tensor_dim[shard_dim] // ( - placement.split_factor * mesh_dim_size + shard_idx_stride_by_mesh_dim[shard_dim][idx] = ( + num_shards_by_tensor_dim[shard_dim] + // (placement.split_factor * mesh_dim_size) ) else: num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size - shard_idx_stride_by_mesh_dim[shard_dim][ - idx - ] = num_shards_by_tensor_dim[shard_dim] + shard_idx_stride_by_mesh_dim[shard_dim][idx] = ( + num_shards_by_tensor_dim[shard_dim] + ) shard_idx = [ sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)]) @@ -205,9 +204,9 @@ def compute_global_tensor_info( ) shard_dim = shard_placement.dim - assert ( - shard_dim < tensor.ndim - ), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}." + assert shard_dim < tensor.ndim, ( + f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}." 
+ ) local_dim_size = tensor_shape[shard_dim] tensor_shape[shard_dim] = local_dim_size * mesh_dim_size diff --git a/torch/distributed/tensor/debug/_comm_mode.py b/torch/distributed/tensor/debug/_comm_mode.py index 6ca47c2e8fde..570161b67682 100644 --- a/torch/distributed/tensor/debug/_comm_mode.py +++ b/torch/distributed/tensor/debug/_comm_mode.py @@ -283,9 +283,9 @@ class CommDebugMode(TorchDispatchMode): "module_type" in self.advanced_module_tracker.module_helper_dict[fqn] and include_module_data ): - json_dict[ - "module_type" - ] = self.advanced_module_tracker.module_helper_dict[fqn]["module_type"] + json_dict["module_type"] = ( + self.advanced_module_tracker.module_helper_dict[fqn]["module_type"] + ) if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]: for ( @@ -659,9 +659,9 @@ class CommDebugMode(TorchDispatchMode): operation_dict["is_bw"] = self.advanced_module_tracker.is_bw # tracks if the operation is part of activation checkpointing - operation_dict[ - "is_activation_checkpointing" - ] = self.advanced_module_tracker.activation_checkpointing + operation_dict["is_activation_checkpointing"] = ( + self.advanced_module_tracker.activation_checkpointing + ) if any(t == DTensor for t in types): for ele in args: diff --git a/torch/distributed/tensor/debug/_visualize_sharding.py b/torch/distributed/tensor/debug/_visualize_sharding.py index 9c1c6df32ed7..fc476514bf55 100644 --- a/torch/distributed/tensor/debug/_visualize_sharding.py +++ b/torch/distributed/tensor/debug/_visualize_sharding.py @@ -108,9 +108,9 @@ def _compute_local_shape_and_global_offset( if isinstance(placement, Shard): shard_dim = placement.dim local_offset = [0] * len(global_shape) - assert shard_dim < len( - local_shape - ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + ) shard_size, shard_offset = placement._local_shard_size_on_dim( local_shape[shard_dim], mesh_dim_size, diff --git a/torch/distributed/tensor/examples/comm_mode_features_example.py b/torch/distributed/tensor/examples/comm_mode_features_example.py index 02906f0dbf1a..da004aef4071 100644 --- a/torch/distributed/tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/tensor/examples/comm_mode_features_example.py @@ -2,6 +2,7 @@ To run the example, use the following command: torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.py -e MLP_operation_tracing """ + import argparse import os from typing import Callable, Union diff --git a/torch/distributed/tensor/examples/convnext_example.py b/torch/distributed/tensor/examples/convnext_example.py index 04b04b69ee9c..9a3c2bbabd9e 100644 --- a/torch/distributed/tensor/examples/convnext_example.py +++ b/torch/distributed/tensor/examples/convnext_example.py @@ -6,6 +6,7 @@ with intermediate activations sharded across multiple GPUs via DTensor To run the example, use the following command: torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py """ + import os import time diff --git a/torch/distributed/tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py index 9261220d175e..b78455bfebd9 100644 --- a/torch/distributed/tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py @@ -3,6 +3,7 @@ The following example demonstrates how to represent torchrec's embedding sharding with the DTensor API.
""" + import argparse import os from functools import cached_property diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 8b1ecbe1dd50..03f051320aea 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -253,22 +253,18 @@ class _AttentionOp(Protocol): key: torch.Tensor, value: torch.Tensor, **kwargs: object, - ) -> tuple[torch.Tensor, ...]: - ... + ) -> tuple[torch.Tensor, ...]: ... class _RingRotater(ABC): @abstractmethod - def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: - ... + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: ... @abstractmethod - def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: - ... + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: ... @abstractmethod - def next_buffer(self) -> torch.Tensor: - ... + def next_buffer(self) -> torch.Tensor: ... class _AllToAllRotater(_RingRotater): @@ -1097,15 +1093,13 @@ class _LoadBalancer(ABC): @abstractmethod def shard( cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... @classmethod @abstractmethod def unshard( cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... class _SequentialSharder(_LoadBalancer): @@ -1147,9 +1141,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer): def shard( cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int ) -> torch.Tensor: - assert ( - cls.ROUND_ROBIN_CYCLE == 2 - ), "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + assert cls.ROUND_ROBIN_CYCLE == 2, ( + "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + ) cp_world_size = mesh.size() cp_rank = mesh.get_local_rank() assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0 @@ -1163,9 +1157,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer): def unshard( cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int ) -> torch.Tensor: - assert ( - cls.ROUND_ROBIN_CYCLE == 2 - ), "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + assert cls.ROUND_ROBIN_CYCLE == 2, ( + "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + ) buffer = buffer.contiguous() cp_world_size = mesh.size() diff --git a/torch/distributed/tensor/experimental/_func_map.py b/torch/distributed/tensor/experimental/_func_map.py index ae02e5c391cd..51861141af5b 100644 --- a/torch/distributed/tensor/experimental/_func_map.py +++ b/torch/distributed/tensor/experimental/_func_map.py @@ -113,9 +113,15 @@ def local_map( >>> device_mesh=device_mesh, >>> ) >>> - >>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor - >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor - >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors + >>> W_dt = distribute_tensor( + ... W, device_mesh, (col_wise) + ... ) # col-wisely sharded W tensor + >>> X_dt = distribute_tensor( + ... X, device_mesh, (row_wise) + ... ) # row-wisely sharded X tensor + >>> Y_dt = local_mm_allreduce_forward( + ... device_mesh, W_dt, X_dt + ... ) # apply local_mm_allreduce_forward to DTensors .. 
note:: This API is currently experimental and subject to change """ @@ -151,9 +157,9 @@ def local_map( ) if in_placements is not None: spec = in_placements[idx] - assert ( - spec is not None - ), f"DTensor input {arg} expects placements but received {spec}!" + assert spec is not None, ( + f"DTensor input {arg} expects placements but received {spec}!" + ) if not isinstance(spec, tuple): spec = tuple(spec) @@ -208,17 +214,17 @@ def local_map( ) for out, spec in zip(flat_out, out_placements_tuple): if isinstance(out, torch.Tensor): - assert not isinstance( - out, DTensor - ), f"torch.Tensor output expected but received {type(out)}: {out}" + assert not isinstance(out, DTensor), ( + f"torch.Tensor output expected but received {type(out)}: {out}" + ) flat_dist_out.append( DTensor.from_local(out, device_mesh, spec, run_check=False) ) else: - assert ( - spec is None - ), f"Non-tensor output {out} expects None placements but received {spec}!" + assert spec is None, ( + f"Non-tensor output {out} expects None placements but received {spec}!" + ) flat_dist_out.append(out) diff --git a/torch/distributed/tensor/experimental/_tp_transform.py b/torch/distributed/tensor/experimental/_tp_transform.py index 679a91c47ab5..52de6cebe684 100644 --- a/torch/distributed/tensor/experimental/_tp_transform.py +++ b/torch/distributed/tensor/experimental/_tp_transform.py @@ -188,9 +188,14 @@ def _mark_sharding( """ Mark the sharding strategy for each node in the graph module. """ - placement_strategies: dict[ - Node, PlacementStrategy - ] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements) + placement_strategies: dict[Node, PlacementStrategy] = ( + _mark_tensor_parallel_shardings( + gm, + graph_signature, + mesh, + parameter_placements, + ) + ) for node in gm.graph.nodes: if node.op == "placeholder": @@ -202,9 +207,9 @@ def _mark_sharding( elif node.op == "call_function": if node.target == operator.getitem: input_nodes = node.all_input_nodes - assert ( - len(input_nodes) == 1 - ), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}" + assert len(input_nodes) == 1, ( + f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}" + ) arg_strategy = placement_strategies[input_nodes[0]] placement_strategies[node] = _create_placement_strategy( node, diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index 9f3e8c3e268a..5282542950c4 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -328,7 +328,9 @@ class DTensorExtensions(FSDPExtensions): self.device_handle = device_handle # we have to use the dynamo disable this way to disable dynamo as the decorator way would # trigger build failure with torch deploy...
- self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform) # type: ignore[method-assign] + self.post_unflatten_transform = torch._dynamo.disable( # type: ignore[method-assign] + self.post_unflatten_transform + ) def pre_flatten_transform( self, diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 8d7ecbb83e50..de003c599468 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -64,9 +64,7 @@ def input_reshard( return module -def _pack_hook_tp( - mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor -) -> Any: # noqa: D401 +def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: # noqa: D401 """Hook function called after FWD to shard input.""" if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements): return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) @@ -84,9 +82,7 @@ def _pack_hook_tp( return x -def _unpack_hook_tp( - mesh: DeviceMesh, input_reshard_dim: int, x: Any -) -> torch.Tensor: # noqa: D401 +def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: # noqa: D401 """Hook function called before activation recomputing in BWD to restore input.""" if ( isinstance(x, DTensor) diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 8c3c723b502c..ca0ba2b7d296 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -38,8 +38,7 @@ class ParallelStyle(ABC): src_data_rank: Optional[int] = 0 @abstractmethod - def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: - ... + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ... class ColwiseParallel(ParallelStyle): @@ -467,19 +466,21 @@ class PrepareModuleInput(ParallelStyle): ) self.use_local_output = use_local_output if self.input_layouts is not None: - assert ( - self.desired_input_layouts is not None - ), "desired module inputs should not be None!" - assert len(self.input_layouts) == len( - self.desired_input_layouts - ), "input_layouts and desired_input_layouts should have same length!" + assert self.desired_input_layouts is not None, ( + "desired module inputs should not be None!" + ) + assert len(self.input_layouts) == len(self.desired_input_layouts), ( + "input_layouts and desired_input_layouts should have same length!" + ) self.with_kwargs = input_kwarg_layouts is not None self.input_kwarg_layouts = input_kwarg_layouts or {} self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {} if self.with_kwargs: assert len(self.input_kwarg_layouts) == len( self.desired_input_kwarg_layouts - ), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" + ), ( + "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" + ) def _prepare_input_arg( self, @@ -494,9 +495,9 @@ class PrepareModuleInput(ParallelStyle): # assert inp.placements[0] == input_layout dt_inp = input else: - assert isinstance( - input, torch.Tensor - ), "expecting input to be a torch.Tensor!" + assert isinstance(input, torch.Tensor), ( + "expecting input to be a torch.Tensor!" 
+ ) dt_inp = DTensor.from_local( input, mesh, (input_layout,), run_check=False ) @@ -517,9 +518,9 @@ class PrepareModuleInput(ParallelStyle): if len(inputs) != len(self.input_layouts): raise ValueError("module inputs and input_layouts should have same length!") - assert ( - self.desired_input_layouts is not None - ), "desired module inputs should not be None!" + assert self.desired_input_layouts is not None, ( + "desired module inputs should not be None!" + ) for inp, input_layout, desired_layout in zip( inputs, self.input_layouts, self.desired_input_layouts ): @@ -551,7 +552,9 @@ class PrepareModuleInput(ParallelStyle): with_kwargs=True, ) # type: ignore[misc] else: - module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg] + module.register_forward_pre_hook( + lambda _, inputs: self._prepare_input_fn(inputs, device_mesh) + ) # type: ignore[misc, call-arg] return module @@ -611,9 +614,9 @@ class PrepareModuleOutput(ParallelStyle): else desired_output_layouts ) self.use_local_output = use_local_output - assert len(self.output_layouts) == len( - self.desired_output_layouts - ), "output_layouts and desired_output_layouts should have same length!" + assert len(self.output_layouts) == len(self.desired_output_layouts), ( + "output_layouts and desired_output_layouts should have same length!" + ) def _prepare_out_fn(self, outputs, device_mesh): prepared_outputs = [] @@ -649,5 +652,7 @@ class PrepareModuleOutput(ParallelStyle): return tuple(prepared_outputs) def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: - module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)) # type: ignore[misc, call-arg] + module.register_forward_hook( + lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh) + ) # type: ignore[misc, call-arg] return module diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 87d769a60a91..ceb9f170fd3e 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -83,9 +83,9 @@ class Shard(Placement): few ranks before calling the collectives (i.e. scatter/all_gather, etc.). This is because collectives usually require equal size tensor inputs """ - assert ( - self.dim <= tensor.ndim - ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + assert self.dim <= tensor.ndim, ( + f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + ) # chunk tensor over dimension `dim` into n slices tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) @@ -468,9 +468,9 @@ class _StridedShard(Shard): """ TODO: currently _StridedShard does not support padding """ - assert ( - self.dim <= tensor.ndim - ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + assert self.dim <= tensor.ndim, ( + f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + ) total_split = num_chunks * self.split_factor assert tensor.size(self.dim) % total_split == 0, ( diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index 3e608778e5a1..ebe36d0eb1bb 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -68,9 +68,9 @@ def _unpack_kwargs( flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...] ) -> tuple[tuple[Any, ...], dict[str, Any]]: """See _pack_kwargs.""" - assert len(kwarg_keys) <= len( - flat_args - ), f"too many keys {len(kwarg_keys)} vs. 
{len(flat_args)}" + assert len(kwarg_keys) <= len(flat_args), ( + f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" + ) if len(kwarg_keys) == 0: return flat_args, {} args = flat_args[: -len(kwarg_keys)] @@ -85,15 +85,13 @@ T = TypeVar("T", torch.Tensor, PackedSequence) @overload def _recursive_to( inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool -) -> list[S]: - ... +) -> list[S]: ... @overload def _recursive_to( inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool -) -> tuple[T]: - ... +) -> tuple[T]: ... def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies): @@ -209,13 +207,13 @@ R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any) @overload -def _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q: - ... +def _apply_to_tensors( + fn: Callable[[torch.Tensor], Q], container: torch.Tensor +) -> Q: ... @overload -def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R: - ... +def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R: ... def _apply_to_tensors(fn, container): diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index 77cd5d33336b..105038641bcc 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -35,6 +35,7 @@ class Bernoulli(ExponentialFamily): probs (Number, Tensor): the probability of sampling `1` logits (Number, Tensor): the log-odds of sampling `1` """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.boolean has_enumerate_support = True diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py index de4370e889c8..e030b648a88e 100644 --- a/torch/distributions/beta.py +++ b/torch/distributions/beta.py @@ -28,6 +28,7 @@ class Beta(ExponentialFamily): concentration0 (float or Tensor): 2nd concentration parameter of the distribution (often referred to as beta) """ + arg_constraints = { "concentration1": constraints.positive, "concentration0": constraints.positive, diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 453cbb8de6b3..6cbfae150844 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -42,6 +42,7 @@ class Binomial(Distribution): probs (Tensor): Event probabilities logits (Tensor): Event log-odds """ + arg_constraints = { "total_count": constraints.nonnegative_integer, "probs": constraints.unit_interval, diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 91d09af5053f..715429c66552 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -47,6 +47,7 @@ class Categorical(Distribution): probs (Tensor): event probabilities logits (Tensor): event log probabilities (unnormalized) """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} has_enumerate_support = True diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py index 9bfc0a390a61..582c08ebb858 100644 --- a/torch/distributions/cauchy.py +++ b/torch/distributions/cauchy.py @@ -29,6 +29,7 @@ class Cauchy(Distribution): loc (float or Tensor): mode or median of the distribution. scale (float or Tensor): half width at half maximum. 
""" + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real has_rsample = True diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py index e9403c7dc377..f175bc44f69e 100644 --- a/torch/distributions/chi2.py +++ b/torch/distributions/chi2.py @@ -22,6 +22,7 @@ class Chi2(Gamma): Args: df (float or Tensor): shape parameter of the distribution """ + arg_constraints = {"df": constraints.positive} def __init__(self, df, validate_args=None): diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index 582485179f19..8907e5b467ab 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -23,7 +23,7 @@ suitable for coordinate-wise optimization algorithms like Adam:: loc = torch.zeros(100, requires_grad=True) unconstrained = torch.zeros(100, requires_grad=True) - scale = transform_to(Normal.arg_constraints['scale'])(unconstrained) + scale = transform_to(Normal.arg_constraints["scale"])(unconstrained) loss = -Normal(loc, scale).log_prob(data).sum() The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where @@ -66,7 +66,6 @@ You can create your own registry by creating a new :class:`ConstraintRegistry` object. """ - from torch.distributions import constraints, transforms from torch.types import _Number @@ -127,9 +126,9 @@ class ConstraintRegistry: Looks up a transform to constrained space, given a constraint object. Usage:: - constraint = Normal.arg_constraints['scale'] + constraint = Normal.arg_constraints["scale"] scale = transform_to(constraint)(torch.zeros(1)) # constrained - u = transform_to(constraint).inv(scale) # unconstrained + u = transform_to(constraint).inv(scale) # unconstrained Args: constraint (:class:`~torch.distributions.constraints.Constraint`): diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 6763bd841bc3..dc27b170bb48 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -165,7 +165,7 @@ def is_dependent(constraint): >>> from torch.distributions import Bernoulli >>> from torch.distributions.constraints import is_dependent - >>> dist = Bernoulli(probs = torch.tensor([0.6], requires_grad=True)) + >>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True)) >>> constraint1 = dist.arg_constraints["probs"] >>> constraint2 = dist.arg_constraints["logits"] @@ -187,6 +187,7 @@ class _DependentProperty(property, _Dependent): def __init__(self, low, high): self.low = low self.high = high + @constraints.dependent_property(is_discrete=False, event_dim=0) def support(self): return constraints.interval(self.low, self.high) @@ -217,8 +218,7 @@ class _DependentProperty(property, _Dependent): Support for syntax to customize static attributes:: @constraints.dependent_property(is_discrete=True, event_dim=1) - def support(self): - ... + def support(self): ... """ return _DependentProperty( fn, is_discrete=self._is_discrete, event_dim=self._event_dim diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py index 66093a81b2c1..b1e8eddfb0ec 100644 --- a/torch/distributions/continuous_bernoulli.py +++ b/torch/distributions/continuous_bernoulli.py @@ -45,6 +45,7 @@ class ContinuousBernoulli(ExponentialFamily): autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019. 
https://arxiv.org/abs/1907.06845 """ + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.unit_interval _mean_carrier_measure = 0 diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index b5d51bc60a6c..f656a0582e89 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -47,6 +47,7 @@ class Dirichlet(ExponentialFamily): concentration (Tensor): concentration parameter of the distribution (often referred to as alpha) """ + arg_constraints = { "concentration": constraints.independent(constraints.positive, 1) } diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py index ae039fc76178..8ca2636e1f52 100644 --- a/torch/distributions/exponential.py +++ b/torch/distributions/exponential.py @@ -24,6 +24,7 @@ class Exponential(ExponentialFamily): Args: rate (float or Tensor): rate = 1 / scale of the distribution """ + arg_constraints = {"rate": constraints.positive} support = constraints.nonnegative has_rsample = True diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py index be05aea149bb..053686c6de07 100644 --- a/torch/distributions/fishersnedecor.py +++ b/torch/distributions/fishersnedecor.py @@ -26,6 +26,7 @@ class FisherSnedecor(Distribution): df1 (float or Tensor): degrees of freedom parameter 1 df2 (float or Tensor): degrees of freedom parameter 2 """ + arg_constraints = {"df1": constraints.positive, "df2": constraints.positive} support = constraints.positive has_rsample = True diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py index 7c3277d9e13e..5e0fe3fc7823 100644 --- a/torch/distributions/gamma.py +++ b/torch/distributions/gamma.py @@ -31,6 +31,7 @@ class Gamma(ExponentialFamily): rate (float or Tensor): rate parameter of the distribution (often referred to as beta), rate = 1 / scale """ + arg_constraints = { "concentration": constraints.positive, "rate": constraints.positive, diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index b23b737de010..b8b05142db5b 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -41,6 +41,7 @@ class Geometric(Distribution): probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1] logits (Number, Tensor): the log-odds of sampling `1`. 
""" + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.nonnegative_integer diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index e550982997dd..623cc7edbda6 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -29,6 +29,7 @@ class Gumbel(TransformedDistribution): loc (float or Tensor): Location parameter of the distribution scale (float or Tensor): Scale parameter of the distribution """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py index 94e98b5aca0a..da17c40da2ed 100644 --- a/torch/distributions/half_cauchy.py +++ b/torch/distributions/half_cauchy.py @@ -29,6 +29,7 @@ class HalfCauchy(TransformedDistribution): Args: scale (float or Tensor): scale of the full Cauchy distribution """ + arg_constraints = {"scale": constraints.positive} support = constraints.nonnegative has_rsample = True diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py index 311774c6cadb..5850f883e908 100644 --- a/torch/distributions/half_normal.py +++ b/torch/distributions/half_normal.py @@ -29,6 +29,7 @@ class HalfNormal(TransformedDistribution): Args: scale (float or Tensor): scale of the full Normal distribution """ + arg_constraints = {"scale": constraints.positive} support = constraints.nonnegative has_rsample = True diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index 6f9cf926f4ee..0442a4c1b483 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -40,6 +40,7 @@ class Independent(Distribution): reinterpreted_batch_ndims (int): the number of batch dims to reinterpret as event dims """ + arg_constraints: dict[str, constraints.Constraint] = {} def __init__( diff --git a/torch/distributions/inverse_gamma.py b/torch/distributions/inverse_gamma.py index f221663ceb6e..aaee976b7f17 100644 --- a/torch/distributions/inverse_gamma.py +++ b/torch/distributions/inverse_gamma.py @@ -31,6 +31,7 @@ class InverseGamma(TransformedDistribution): rate (float or Tensor): rate = 1 / scale of the distribution (often referred to as beta) """ + arg_constraints = { "concentration": constraints.positive, "rate": constraints.positive, diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py index 722705e0d219..d38efb631e86 100644 --- a/torch/distributions/kumaraswamy.py +++ b/torch/distributions/kumaraswamy.py @@ -37,6 +37,7 @@ class Kumaraswamy(TransformedDistribution): concentration0 (float or Tensor): 2nd concentration parameter of the distribution (often referred to as beta) """ + arg_constraints = { "concentration1": constraints.positive, "concentration0": constraints.positive, diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py index 72463f228ac6..39ef9b1efdb7 100644 --- a/torch/distributions/laplace.py +++ b/torch/distributions/laplace.py @@ -25,6 +25,7 @@ class Laplace(Distribution): loc (float or Tensor): mean of the distribution scale (float or Tensor): scale of the distribution """ + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real has_rsample = True diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py index 479568bdd428..a18f2ed9f52a 100644 --- a/torch/distributions/lkj_cholesky.py +++ b/torch/distributions/lkj_cholesky.py @@ -57,6 
+57,7 @@ class LKJCholesky(Distribution):
     Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
     Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
     """
+
     arg_constraints = {"concentration": constraints.positive}
     support = constraints.corr_cholesky
diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py
index 4b9f4f7217c9..a048f94286c8 100644
--- a/torch/distributions/log_normal.py
+++ b/torch/distributions/log_normal.py
@@ -28,6 +28,7 @@ class LogNormal(TransformedDistribution):
         loc (float or Tensor): mean of log of distribution
         scale (float or Tensor): standard deviation of log of the distribution
     """
+
     arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
     support = constraints.positive
     has_rsample = True
diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py
index 7ae28eda3804..a8f7c099d1e8 100644
--- a/torch/distributions/logistic_normal.py
+++ b/torch/distributions/logistic_normal.py
@@ -32,6 +32,7 @@ class LogisticNormal(TransformedDistribution):
         tensor([ 0.7653, 0.0341, 0.0579, 0.1427])

     """
+
     arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
     support = constraints.simplex
     has_rsample = True
diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py
index 97257a0edcf5..c6f739a595a3 100644
--- a/torch/distributions/lowrank_multivariate_normal.py
+++ b/torch/distributions/lowrank_multivariate_normal.py
@@ -61,7 +61,9 @@ class LowRankMultivariateNormal(Distribution):
     Example:
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
         >>> # xdoctest: +IGNORE_WANT("non-deterministic")
-        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
+        >>> m = LowRankMultivariateNormal(
+        ...     torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2)
+        ... )
         >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
         tensor([-0.2102, -0.5429])
@@ -82,6 +84,7 @@ class LowRankMultivariateNormal(Distribution):
         capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
     """
+
     arg_constraints = {
         "loc": constraints.real_vector,
         "cov_factor": constraints.independent(constraints.real, 2),
diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py
index 4d76e2f505a9..1fc2c1052d03 100644
--- a/torch/distributions/mixture_same_family.py
+++ b/torch/distributions/mixture_same_family.py
@@ -51,6 +51,7 @@ class MixtureSameFamily(Distribution):
         component_distribution: `torch.distributions.Distribution`-like
             instance. Right-most batch dimension indexes component.
     """
+
     arg_constraints: dict[str, constraints.Constraint] = {}
     has_rsample = False
diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py
index 2863167748fa..85a227f5c403 100644
--- a/torch/distributions/multinomial.py
+++ b/torch/distributions/multinomial.py
@@ -47,6 +47,7 @@ class Multinomial(Distribution):
         probs (Tensor): event probabilities
         logits (Tensor): event log probabilities (unnormalized)
     """
+
     arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
     total_count: int
diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py
index fa6657ff0644..849ee4170015 100644
--- a/torch/distributions/multivariate_normal.py
+++ b/torch/distributions/multivariate_normal.py
@@ -121,6 +121,7 @@ class MultivariateNormal(Distribution):
         :attr:`precision_matrix` is passed instead, it is only used to compute
         the corresponding lower triangular matrices using a Cholesky decomposition.
     """
+
     arg_constraints = {
         "loc": constraints.real_vector,
         "covariance_matrix": constraints.positive_definite,
diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py
index 44cb4a826ee9..e5b0e128efe6 100644
--- a/torch/distributions/negative_binomial.py
+++ b/torch/distributions/negative_binomial.py
@@ -30,6 +30,7 @@ class NegativeBinomial(Distribution):
         probs (Tensor): Event probabilities of success in the half open interval [0, 1)
         logits (Tensor): Event log-odds for probabilities of success
     """
+
     arg_constraints = {
         "total_count": constraints.greater_than_eq(0),
         "probs": constraints.half_open_interval(0.0, 1.0),
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index 6fd33faf0732..86e30ba450f5 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -29,6 +29,7 @@ class Normal(ExponentialFamily):
         scale (float or Tensor): standard deviation of the distribution
             (often referred to as sigma)
     """
+
     arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
     support = constraints.real
     has_rsample = True
diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py
index 909f4f63dcaf..7e0bc03c5aba 100644
--- a/torch/distributions/one_hot_categorical.py
+++ b/torch/distributions/one_hot_categorical.py
@@ -39,6 +39,7 @@ class OneHotCategorical(Distribution):
         probs (Tensor): event probabilities
         logits (Tensor): event log probabilities (unnormalized)
     """
+
     arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
     support = constraints.one_hot
     has_enumerate_support = True
@@ -125,6 +126,7 @@ class OneHotCategoricalStraightThrough(OneHotCategorical):
     [1] Estimating or Propagating Gradients Through Stochastic Neurons for
     Conditional Computation (Bengio et al., 2013)
     """
+
     has_rsample = True

     def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py
index 2ef4a276926f..2cc1e298ba25 100644
--- a/torch/distributions/pareto.py
+++ b/torch/distributions/pareto.py
@@ -27,6 +27,7 @@ class Pareto(TransformedDistribution):
         scale (float or Tensor): Scale parameter of the distribution
         alpha (float or Tensor): Shape parameter of the distribution
     """
+
     arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive}

     def __init__(
diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py
index face4a05e570..c3b4bacc54cb 100644
--- a/torch/distributions/poisson.py
+++ b/torch/distributions/poisson.py
@@ -29,6 +29,7 @@ class Poisson(ExponentialFamily):
     Args:
         rate (Number, Tensor): the rate parameter
     """
+
     arg_constraints = {"rate": constraints.nonnegative}
     support = constraints.nonnegative_integer
diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py
index 9d781c430294..4c1549660313 100644
--- a/torch/distributions/relaxed_bernoulli.py
+++ b/torch/distributions/relaxed_bernoulli.py
@@ -37,6 +37,7 @@ class LogitRelaxedBernoulli(Distribution):
     [2] Categorical Reparametrization with Gumbel-Softmax
     (Jang et al., 2017)
     """
+
     arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
     support = constraints.real
@@ -126,6 +127,7 @@ class RelaxedBernoulli(TransformedDistribution):
         probs (Number, Tensor): the probability of sampling `1`
         logits (Number, Tensor): the log-odds of sampling `1`
     """
+
     arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
     support = constraints.unit_interval
     has_rsample = True
diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py
index c319335175be..97ae3ed1857b 100644
--- a/torch/distributions/relaxed_categorical.py
+++ b/torch/distributions/relaxed_categorical.py
@@ -35,6 +35,7 @@ class ExpRelaxedCategorical(Distribution):
    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al., 2017)
    """
+
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
@@ -116,6 +117,7 @@ class RelaxedOneHotCategorical(TransformedDistribution):
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """
+
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True
diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py
index 5559af8ffac3..e141939b2745 100644
--- a/torch/distributions/studentT.py
+++ b/torch/distributions/studentT.py
@@ -29,6 +29,7 @@ class StudentT(Distribution):
        loc (float or Tensor): mean of the distribution
        scale (float or Tensor): scale of the distribution
    """
+
    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real,
diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py
index 25e62a03b55a..02792ce9d309 100644
--- a/torch/distributions/transformed_distribution.py
+++ b/torch/distributions/transformed_distribution.py
@@ -46,6 +46,7 @@ class TransformedDistribution(Distribution):
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    """
+
    arg_constraints: dict[str, constraints.Constraint] = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py
index 2aa91f881d3a..8958f1a63c87 100644
--- a/torch/distributions/transforms.py
+++ b/torch/distributions/transforms.py
@@ -543,6 +543,7 @@ class ExpTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \exp(x)`.
    """
+
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
@@ -565,6 +566,7 @@ class PowerTransform(Transform):
    r"""
    Transform via the mapping :math:`y = x^{\text{exponent}}`.
    """
+
    domain = constraints.positive
    codomain = constraints.positive
    bijective = True
@@ -612,6 +614,7 @@ class SigmoidTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
    """
+
    domain = constraints.real
    codomain = constraints.unit_interval
    bijective = True
@@ -637,6 +640,7 @@ class SoftplusTransform(Transform):
    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
    The implementation reverts to the linear function when :math:`x > 20`.
    """
+
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
@@ -660,15 +664,24 @@ class TanhTransform(Transform):
    Transform via the mapping :math:`y = \tanh(x)`.

    It is equivalent to
-    ```
-    ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
-    ```
+
+    .. code-block:: python
+
+        ComposeTransform(
+            [
+                AffineTransform(0.0, 2.0),
+                SigmoidTransform(),
+                AffineTransform(-1.0, 2.0),
+            ]
+        )
+
    However this might not be numerically stable, thus it is recommended to use `TanhTransform`
    instead.

    Note that one should use `cache_size=1` when it comes to `NaN/Inf` values.
    """
+
    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)
    bijective = True
@@ -692,9 +705,8 @@ class TanhTransform(Transform):


class AbsTransform(Transform):
-    r"""
-    Transform via the mapping :math:`y = |x|`.
-    """
+    r"""Transform via the mapping :math:`y = |x|`."""
+
    domain = constraints.real
    codomain = constraints.positive
@@ -719,6 +731,7 @@ class AffineTransform(Transform):
            for univariate random variables, 1 for distributions over vectors,
            2 for distributions over matrices, etc.
    """
+
    bijective = True

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
@@ -822,6 +835,7 @@ class CorrCholeskyTransform(Transform):
    - Applies :math:`s_i = StickBreakingTransform(z_i)`.
    - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`.
    """
+
    domain = constraints.real_vector
    codomain = constraints.corr_cholesky
    bijective = True
@@ -897,6 +911,7 @@ class SoftmaxTransform(Transform):
    coordinate-wise (except for the final normalization), and thus is
    appropriate for coordinate-wise optimization algorithms.
    """
+
    domain = constraints.real_vector
    codomain = constraints.simplex
diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py
index f4b07f3fe9c1..31007c924de0 100644
--- a/torch/distributions/uniform.py
+++ b/torch/distributions/uniform.py
@@ -26,6 +26,7 @@ class Uniform(Distribution):
        low (float or Tensor): lower range (inclusive).
        high (float or Tensor): upper range (exclusive).
    """
+
    # TODO allow (loc,scale) parameterization to allow independent constraints.
    arg_constraints = {
        "low": constraints.dependent(is_discrete=False, event_dim=0),
diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py
index 1954e1bf357d..f83d75c904ab 100644
--- a/torch/distributions/utils.py
+++ b/torch/distributions/utils.py
@@ -151,12 +151,10 @@ class lazy_property(Generic[T, R]):
    @overload
    def __get__(
        self, instance: None, obj_type: Any = None
-    ) -> "_lazy_property_and_property[T, R]":
-        ...
+    ) -> "_lazy_property_and_property[T, R]": ...

    @overload
-    def __get__(self, instance: T, obj_type: Any = None) -> R:
-        ...
+    def __get__(self, instance: T, obj_type: Any = None) -> R: ...

    def __get__(
        self, instance: Union[T, None], obj_type: Any = None
diff --git a/torch/distributions/weibull.py b/torch/distributions/weibull.py
index 9d376fe11064..e7b3c5e0cebe 100644
--- a/torch/distributions/weibull.py
+++ b/torch/distributions/weibull.py
@@ -27,6 +27,7 @@ class Weibull(TransformedDistribution):
        scale (float or Tensor): Scale parameter of distribution (lambda).
        concentration (float or Tensor): Concentration parameter of distribution (k/shape).
    """
+
    arg_constraints = {
        "scale": constraints.positive,
        "concentration": constraints.positive,
diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py
index b234c3c21a03..225aeeb97430 100644
--- a/torch/distributions/wishart.py
+++ b/torch/distributions/wishart.py
@@ -39,7 +39,7 @@ class Wishart(ExponentialFamily):
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
-        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
+        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
@@ -63,6 +63,7 @@ class Wishart(ExponentialFamily):
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a SampleCovariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """
+
    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
@@ -83,7 +84,9 @@ class Wishart(ExponentialFamily):
    ):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
-        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
+        ) == 1, (
+            "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
+        )
        param = next(
            p