Fixes #112604

Fixes docstrings by following `pydocstyle` output.

- torch/nn/parallel/distributed.py

Before: 84
```
torch/nn/parallel/distributed.py:1 at module level: D100: Missing docstring in public module
torch/nn/parallel/distributed.py:92 in private function `_cast_buffers`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:103 in private function `_setup_mixed_precision_params`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:103 in private function `_setup_mixed_precision_params`: D401: First line should be in imperative mood (perhaps 'Create', not 'Creates')
torch/nn/parallel/distributed.py:143 in private function `_find_tensors`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:273 in private method `__init__`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:273 in private method `__init__`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
torch/nn/parallel/distributed.py:287 in private method `main_hook`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:287 in private method `main_hook`: D400: First line should end with a period (not 'd')
torch/nn/parallel/distributed.py:324 in private method `post_hook`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:324 in private method `post_hook`: D400: First line should end with a period (not 'l')
torch/nn/parallel/distributed.py:324 in private method `post_hook`: D401: First line should be in imperative mood (perhaps 'Sync', not 'Syncs')
torch/nn/parallel/distributed.py:332 in public class `DistributedDataParallel`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:332 in public class `DistributedDataParallel`: D400: First line should end with a period (not 'n')
torch/nn/parallel/distributed.py:633 in public method `__init__`: D107: Missing docstring in __init__
torch/nn/parallel/distributed.py:960 in private method `_fire_reducer_autograd_hook`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:960 in private method `_fire_reducer_autograd_hook`: D401: First line should be in imperative mood (perhaps 'Fire', not 'Fires')
torch/nn/parallel/distributed.py:969 in private method `_root_copy_hook`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:969 in private method `_root_copy_hook`: D400: First line should end with a period (not 's')
torch/nn/parallel/distributed.py:1012 in private method `_module_wait_for_copy_hook`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1012 in private method `_module_wait_for_copy_hook`: D400: First line should end with a period (not 'e')
torch/nn/parallel/distributed.py:1050 in private method `_ddp_init_helper`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1050 in private method `_ddp_init_helper`: D400: First line should end with a period (not ':')
torch/nn/parallel/distributed.py:1050 in private method `_ddp_init_helper`: D401: First line should be in imperative mood (perhaps 'Initialize', not 'Initialization')
torch/nn/parallel/distributed.py:1146 in public method `__getstate__`: D105: Missing docstring in magic method
torch/nn/parallel/distributed.py:1154 in public method `__setstate__`: D105: Missing docstring in magic method
torch/nn/parallel/distributed.py:1222 in private method `_assign_modules_buffers`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1222 in private method `_assign_modules_buffers`: D400: First line should end with a period (not 'o')
torch/nn/parallel/distributed.py:1222 in private method `_assign_modules_buffers`: D401: First line should be in imperative mood (perhaps 'Assign', not 'Assigns')
torch/nn/parallel/distributed.py:1277 in private method `_get_parameters`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:1277 in private method `_get_parameters`: D400: First line should end with a period (not 's')
torch/nn/parallel/distributed.py:1277 in private method `_get_parameters`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/nn/parallel/distributed.py:1312 in public method `no_sync`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1312 in public method `no_sync`: D400: First line should end with a period (not 'P')
torch/nn/parallel/distributed.py:1312 in public method `no_sync`: D401: First line should be in imperative mood; try rephrasing (found 'A')
torch/nn/parallel/distributed.py:1340 in private method `_get_active_ddp_module`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:1340 in private method `_get_active_ddp_module`: D403: First word of the first line should be properly capitalized ('Torchdynamo', not 'TorchDynamo')
torch/nn/parallel/distributed.py:1517 in public method `forward`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1527 in public method `scatter`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1530 in public method `to_kwargs`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1539 in public method `gather`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1542 in public method `train`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1617 in public method `join`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1617 in public method `join`: D400: First line should end with a period (not 'f')
torch/nn/parallel/distributed.py:1617 in public method `join`: D401: First line should be in imperative mood; try rephrasing (found 'A')
torch/nn/parallel/distributed.py:1723 in public method `join_hook`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1723 in public method `join_hook`: D400: First line should end with a period (not 'y')
torch/nn/parallel/distributed.py:1723 in public method `join_hook`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/nn/parallel/distributed.py:1752 in public method `join_device`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1756 in public method `join_process_group`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1765 in private method `_register_buffer_comm_hook`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1765 in private method `_register_buffer_comm_hook`: D400: First line should end with a period (not 'e')
torch/nn/parallel/distributed.py:1765 in private method `_register_buffer_comm_hook`: D401: First line should be in imperative mood (perhaps 'Allow', not 'Allows')
torch/nn/parallel/distributed.py:1805 in public method `register_comm_hook`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1805 in public method `register_comm_hook`: D400: First line should end with a period (not 'a')
torch/nn/parallel/distributed.py:1805 in public method `register_comm_hook`: D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/nn/parallel/distributed.py:1887 in private method `_register_builtin_comm_hook`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1887 in private method `_register_builtin_comm_hook`: D400: First line should end with a period (not 'P')
torch/nn/parallel/distributed.py:1887 in private method `_register_builtin_comm_hook`: D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/nn/parallel/distributed.py:1914 in private method `_register_fused_optim`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1914 in private method `_register_fused_optim`: D400: First line should end with a period (not 'a')
torch/nn/parallel/distributed.py:1914 in private method `_register_fused_optim`: D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/nn/parallel/distributed.py:2005 in public method `will_sync_module_buffers`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:2060 in private method `_default_broadcast_coalesced`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2060 in private method `_default_broadcast_coalesced`: D400: First line should end with a period (not 'e')
torch/nn/parallel/distributed.py:2128 in private method `_get_data_parallel_params`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:2128 in private method `_get_data_parallel_params`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/nn/parallel/distributed.py:2141 in private method `_set_params_and_buffers_to_ignore_for_model`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2141 in private method `_set_params_and_buffers_to_ignore_for_model`: D400: First line should end with a period (not 'r')
torch/nn/parallel/distributed.py:2141 in private method `_set_params_and_buffers_to_ignore_for_model`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
torch/nn/parallel/distributed.py:2170 in private method `_get_ddp_logging_data`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2170 in private method `_get_ddp_logging_data`: D400: First line should end with a period (not 's')
torch/nn/parallel/distributed.py:2170 in private method `_get_ddp_logging_data`: D401: First line should be in imperative mood; try rephrasing (found 'This')
torch/nn/parallel/distributed.py:2184 in private method `_set_ddp_runtime_logging_sample_rate`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2184 in private method `_set_ddp_runtime_logging_sample_rate`: D400: First line should end with a period (not 'g')
torch/nn/parallel/distributed.py:2184 in private method `_set_ddp_runtime_logging_sample_rate`: D401: First line should be in imperative mood; try rephrasing (found 'This')
torch/nn/parallel/distributed.py:2202 in private method `_set_static_graph`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2202 in private method `_set_static_graph`: D400: First line should end with a period (not 'l')
torch/nn/parallel/distributed.py:2202 in private method `_set_static_graph`: D401: First line should be in imperative mood; try rephrasing (found 'It')
torch/nn/parallel/distributed.py:2227 in private method `_remove_autograd_hooks`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:2227 in private method `_remove_autograd_hooks`: D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
torch/nn/parallel/distributed.py:2233 in private method `_check_reducer_finalized`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2233 in private method `_check_reducer_finalized`: D400: First line should end with a period (not 'd')
torch/nn/parallel/distributed.py:2233 in private method `_check_reducer_finalized`: D401: First line should be in imperative mood (perhaps 'Check', not 'Checks')
84
```

After: 12
```
torch/nn/parallel/distributed.py:1 at module level: D100: Missing docstring in public module
torch/nn/parallel/distributed.py:618 in public method `__init__`: D107: Missing docstring in __init__
torch/nn/parallel/distributed.py:1133 in public method `__getstate__`: D105: Missing docstring in magic method
torch/nn/parallel/distributed.py:1141 in public method `__setstate__`: D105: Missing docstring in magic method
torch/nn/parallel/distributed.py:1503 in public method `forward`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1513 in public method `scatter`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1516 in public method `to_kwargs`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1525 in public method `gather`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1528 in public method `train`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1734 in public method `join_device`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1738 in public method `join_process_group`: D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1986 in public method `will_sync_module_buffers`: D102: Missing docstring in public method
12
```

- torch/nn/utils/_named_member_accessor.py

Before: 23
```
torch/nn/utils/_named_member_accessor.py:12 in public function `set_tensor`: D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:29 in public function `swap_tensor`: D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:85 in public function `swap_submodule`: D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:109 in public class `NamedMemberAccessor`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:109 in public class `NamedMemberAccessor`: D400: First line should end with a period (not 's')
torch/nn/utils/_named_member_accessor.py:115 in public method `__init__`: D107: Missing docstring in __init__
torch/nn/utils/_named_member_accessor.py:122 in public method `get_submodule`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:155 in public method `swap_submodule`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:164 in public method `get_tensor`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:185 in public method `set_tensor`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:194 in public method `del_tensor`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:211 in public method `swap_tensor`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:224 in public method `get_tensors`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:233 in public method `set_tensors`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:249 in public method `set_tensors_dict`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:261 in public method `del_tensors`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:276 in public method `swap_tensors`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:296 in public method `swap_tensors_dict`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:325 in public method `check_keys`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/utils/_named_member_accessor.py:340 in public method `named_parameters`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/utils/_named_member_accessor.py:349 in public method `named_buffers`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/utils/_named_member_accessor.py:358 in public method `named_tensors`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/utils/_named_member_accessor.py:368 in public method `named_modules`: D200: One-line docstring should fit on one line with quotes (found 3)
23
```

After: 4
```
torch/nn/utils/_named_member_accessor.py:12 in public function `set_tensor`: D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:29 in public function `swap_tensor`: D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:85 in public function `swap_submodule`: D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:116 in public method `__init__`: D107: Missing docstring in __init__
4
```

- torch/nn/utils/_per_sample_grad.py

Before: 3
```
torch/nn/utils/_per_sample_grad.py:12 in public function `call_for_per_sample_grads`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_per_sample_grad.py:12 in public function `call_for_per_sample_grads`: D400: First line should end with a period (not ')')
torch/nn/utils/_per_sample_grad.py:12 in public function `call_for_per_sample_grads`: D402: First line should not be the function's "signature"
3
```

After: 0
```
0
```

- torch/nn/utils/init.py

Before: 3
```
torch/nn/utils/init.py:1 at module level: D100: Missing docstring in public module
torch/nn/utils/init.py:6 in public function `skip_init`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/init.py:6 in public function `skip_init`: D400: First line should end with a period (not 'g')
3
```

After: 1
```
torch/nn/utils/init.py:1 at module level: D100: Missing docstring in public module
1
```

- torch/nn/utils/memory_format.py

Before: 4
```
torch/nn/utils/memory_format.py:1 at module level: D100: Missing docstring in public module
torch/nn/utils/memory_format.py:5 in public function `convert_conv2d_weight_memory_format`: D202: No blank lines allowed after function docstring (found 1)
torch/nn/utils/memory_format.py:5 in public function `convert_conv2d_weight_memory_format`: D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/memory_format.py:5 in public function `convert_conv2d_weight_memory_format`: D400: First line should end with a period (not '`')
4
```

After: 1
```
torch/nn/utils/memory_format.py:1 at module level: D100: Missing docstring in public module
1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112657
Approved by: https://github.com/fduwjj
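The fixes above are mechanical docstring rewrites, not behavior changes. As a rough illustration of the three most common codes in the lists (D205, D400, D401), a failing docstring and a passing rewrite might look like the sketch below; `sync_buffers` is a made-up method, not taken from the diff:

```python
class Before:
    def sync_buffers(self):  # hypothetical example, not from the PR
        """Syncs module buffers across ranks
        using the process group's broadcast collective"""
        # Trips D205 (no blank line after the summary line),
        # D400 (summary does not end with a period), and
        # D401 (summary is not in the imperative mood).


class After:
    def sync_buffers(self):
        """Sync module buffers across ranks.

        Uses the process group's broadcast collective.
        """
```

The per-file counts can be reproduced with `pydocstyle --count <file>`.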
torch/nn/utils/_per_sample_grad.py · 103 lines · 5.5 KiB · Python
import functools

import torch
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight

from torch.utils import _pytree as pytree


# dependency on `functional_call` means that this can't be exposed in utils
# without creating circular dependency
def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum", batch_first=True):
    r"""
    Return a forward function for a module, populating grad_sample with per sample gradients on backward invocation.

    Args:
        module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
          parameters will compute per sample gradients, located in a ``grad_sample``
          field when ``backward`` is invoked
        batch_size: The batch size of the input. If None is passed, all tensor arguments in args and kwargs must have
          the same batch size, which is the size of the first dimension. Otherwise, it must be passed manually.
          Default: None
        loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If
          "mean", per sample gradients will be scaled by the batch size to offset the cross-batch interaction from
          running mean across a batch. Must be "mean" or "sum". Default: "sum"
        batch_first: Indicates if the batch dimension is the first dimension. If True, the batch dimension is the first
          dimension. If False, it's the second dimension. Default: True.

    Examples::
        >>> # xdoctest: +SKIP
        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model)(batched_input).sum()
        >>> res.backward()
        >>> assert model.weight.shape == (3, 4)
        >>> assert model.weight.grad_sample.shape == (5, 3, 4)
        >>> assert model.weight.grad is None
        >>> assert model.bias.shape == (3,)
        >>> assert model.bias.grad_sample.shape == (5, 3)
        >>> assert model.bias.grad is None

    An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be
    if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all
    grad_outputs by 1 / batch_size from cross-batch interaction.
        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model, batch_size=5, loss_reduction="mean")(batched_input).mean()
        >>> res.backward()

    Note::
        Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
        rewrites that wrap an `nn.Linear` module. See Opacus for an example.
    """

    def maybe_build_expanded_weight(og_tensor, batch_size):
        if og_tensor.requires_grad:
            return ExpandedWeight(og_tensor, batch_size, loss_reduction)
        else:
            return og_tensor

    def compute_batch_size(*args, **kwargs):
        args_and_kwargs = pytree.arg_tree_leaves(*args, **kwargs)
        batch_size = None
        for arg in args_and_kwargs:
            if not isinstance(arg, torch.Tensor):
                continue

            arg_batch_size = arg.shape[0] if batch_first else arg.shape[1]
            if batch_size is not None and batch_size != arg_batch_size:
                raise RuntimeError("When computing batch size, found at least one input with batch size "
                                   f"{batch_size} and one with batch size {arg_batch_size}. Please specify it "
                                   "explicitly using the batch size kwarg in call_for_per_sample_grads")
            batch_size = arg_batch_size
        if batch_size is None:
            raise RuntimeError("Unable to find a tensor in the passed args and kwargs. They may not be pytree-able "
                               "and so ExpandedWeights cannot compute the batch size from the inputs. Please specify "
                               "it explicitly")
        return batch_size

    if loss_reduction not in ["sum", "mean"]:
        raise RuntimeError(f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}")

    if not isinstance(module, torch.nn.Module):
        raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}")
    if not (batch_size is None or isinstance(batch_size, int)):
        raise RuntimeError(f"Batch size passed must be None or an integer, got {type(batch_size).__name__}")
    if batch_size is not None and batch_size < 1:
        raise RuntimeError(f"Batch size must be positive, got {batch_size}")
    for weight in module.parameters():
        if hasattr(weight, "grad_sample") and weight.grad_sample is not None:  # type: ignore[attr-defined]
            raise RuntimeError("Current Expanded Weights accumulates the gradients, which will be incorrect for multiple "
                               f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or "
                               "post an issue to pytorch/pytorch to prioritize correct behavior")

    @functools.wraps(module.forward)
    def wrapper(*args, **kwargs):
        wrapper_batch_size = batch_size
        if wrapper_batch_size is None:
            wrapper_batch_size = compute_batch_size(*args, **kwargs)

        params = {name: maybe_build_expanded_weight(value, wrapper_batch_size) for (name, value) in module.named_parameters()}
        return torch.func.functional_call(module, params, args, kwargs)
    return wrapper
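As a quick sanity check of the contract documented above (a sketch, not part of the file): the `grad_sample` fields the wrapper populates with `loss_reduction="sum"` should match gradients computed one sample at a time, at least for supported modules such as `nn.Linear`. This assumes the private import path `torch.nn.utils._per_sample_grad`, which, being private, may change:

```python
import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

torch.manual_seed(0)
model = nn.Linear(4, 3)
x = torch.randn(5, 4)  # batch size of 5

# One batched backward populates weight.grad_sample and bias.grad_sample.
call_for_per_sample_grads(model, batch_size=5)(x).sum().backward()

# Reference: differentiate each sample's loss independently.
for i in range(5):
    g_w, g_b = torch.autograd.grad(model(x[i]).sum(), (model.weight, model.bias))
    assert torch.allclose(model.weight.grad_sample[i], g_w, atol=1e-6)
    assert torch.allclose(model.bias.grad_sample[i], g_b, atol=1e-6)
```

Note that the wrapper is called only once here; per the accumulation check in the function body, `grad_sample` must be cleared before a second wrapped backward pass.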