Files
pytorch/torch/nn/utils/_per_sample_grad.py
NVS Abhilash db66f15785 docs: fix docstrings in distributed.py and others (fixes #112604) (#112657)
Fixes #112604

Fixes docstrings by following the `pydocstyle` output.

- torch/nn/parallel/distributed.py
Before: 84
```
torch/nn/parallel/distributed.py:1 at module level:
        D100: Missing docstring in public module
torch/nn/parallel/distributed.py:92 in private function `_cast_buffers`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:103 in private function `_setup_mixed_precision_params`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:103 in private function `_setup_mixed_precision_params`:
        D401: First line should be in imperative mood (perhaps 'Create', not 'Creates')
torch/nn/parallel/distributed.py:143 in private function `_find_tensors`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:273 in private method `__init__`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:273 in private method `__init__`:
        D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
torch/nn/parallel/distributed.py:287 in private method `main_hook`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:287 in private method `main_hook`:
        D400: First line should end with a period (not 'd')
torch/nn/parallel/distributed.py:324 in private method `post_hook`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:324 in private method `post_hook`:
        D400: First line should end with a period (not 'l')
torch/nn/parallel/distributed.py:324 in private method `post_hook`:
        D401: First line should be in imperative mood (perhaps 'Sync', not 'Syncs')
torch/nn/parallel/distributed.py:332 in public class `DistributedDataParallel`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:332 in public class `DistributedDataParallel`:
        D400: First line should end with a period (not 'n')
torch/nn/parallel/distributed.py:633 in public method `__init__`:
        D107: Missing docstring in __init__
torch/nn/parallel/distributed.py:960 in private method `_fire_reducer_autograd_hook`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:960 in private method `_fire_reducer_autograd_hook`:
        D401: First line should be in imperative mood (perhaps 'Fire', not 'Fires')
torch/nn/parallel/distributed.py:969 in private method `_root_copy_hook`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:969 in private method `_root_copy_hook`:
        D400: First line should end with a period (not 's')
torch/nn/parallel/distributed.py:1012 in private method `_module_wait_for_copy_hook`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1012 in private method `_module_wait_for_copy_hook`:
        D400: First line should end with a period (not 'e')
torch/nn/parallel/distributed.py:1050 in private method `_ddp_init_helper`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1050 in private method `_ddp_init_helper`:
        D400: First line should end with a period (not ':')
torch/nn/parallel/distributed.py:1050 in private method `_ddp_init_helper`:
        D401: First line should be in imperative mood (perhaps 'Initialize', not 'Initialization')
torch/nn/parallel/distributed.py:1146 in public method `__getstate__`:
        D105: Missing docstring in magic method
torch/nn/parallel/distributed.py:1154 in public method `__setstate__`:
        D105: Missing docstring in magic method
torch/nn/parallel/distributed.py:1222 in private method `_assign_modules_buffers`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1222 in private method `_assign_modules_buffers`:
        D400: First line should end with a period (not 'o')
torch/nn/parallel/distributed.py:1222 in private method `_assign_modules_buffers`:
        D401: First line should be in imperative mood (perhaps 'Assign', not 'Assigns')
torch/nn/parallel/distributed.py:1277 in private method `_get_parameters`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:1277 in private method `_get_parameters`:
        D400: First line should end with a period (not 's')
torch/nn/parallel/distributed.py:1277 in private method `_get_parameters`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/nn/parallel/distributed.py:1312 in public method `no_sync`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1312 in public method `no_sync`:
        D400: First line should end with a period (not 'P')
torch/nn/parallel/distributed.py:1312 in public method `no_sync`:
        D401: First line should be in imperative mood; try rephrasing (found 'A')
torch/nn/parallel/distributed.py:1340 in private method `_get_active_ddp_module`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:1340 in private method `_get_active_ddp_module`:
        D403: First word of the first line should be properly capitalized ('Torchdynamo', not 'TorchDynamo')
torch/nn/parallel/distributed.py:1517 in public method `forward`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1527 in public method `scatter`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1530 in public method `to_kwargs`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1539 in public method `gather`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1542 in public method `train`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1617 in public method `join`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1617 in public method `join`:
        D400: First line should end with a period (not 'f')
torch/nn/parallel/distributed.py:1617 in public method `join`:
        D401: First line should be in imperative mood; try rephrasing (found 'A')
torch/nn/parallel/distributed.py:1723 in public method `join_hook`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1723 in public method `join_hook`:
        D400: First line should end with a period (not 'y')
torch/nn/parallel/distributed.py:1723 in public method `join_hook`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/nn/parallel/distributed.py:1752 in public method `join_device`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1756 in public method `join_process_group`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1765 in private method `_register_buffer_comm_hook`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1765 in private method `_register_buffer_comm_hook`:
        D400: First line should end with a period (not 'e')
torch/nn/parallel/distributed.py:1765 in private method `_register_buffer_comm_hook`:
        D401: First line should be in imperative mood (perhaps 'Allow', not 'Allows')
torch/nn/parallel/distributed.py:1805 in public method `register_comm_hook`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1805 in public method `register_comm_hook`:
        D400: First line should end with a period (not 'a')
torch/nn/parallel/distributed.py:1805 in public method `register_comm_hook`:
        D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/nn/parallel/distributed.py:1887 in private method `_register_builtin_comm_hook`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1887 in private method `_register_builtin_comm_hook`:
        D400: First line should end with a period (not 'P')
torch/nn/parallel/distributed.py:1887 in private method `_register_builtin_comm_hook`:
        D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/nn/parallel/distributed.py:1914 in private method `_register_fused_optim`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:1914 in private method `_register_fused_optim`:
        D400: First line should end with a period (not 'a')
torch/nn/parallel/distributed.py:1914 in private method `_register_fused_optim`:
        D401: First line should be in imperative mood (perhaps 'Register', not 'Registers')
torch/nn/parallel/distributed.py:2005 in public method `will_sync_module_buffers`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:2060 in private method `_default_broadcast_coalesced`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2060 in private method `_default_broadcast_coalesced`:
        D400: First line should end with a period (not 'e')
torch/nn/parallel/distributed.py:2128 in private method `_get_data_parallel_params`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:2128 in private method `_get_data_parallel_params`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
torch/nn/parallel/distributed.py:2141 in private method `_set_params_and_buffers_to_ignore_for_model`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2141 in private method `_set_params_and_buffers_to_ignore_for_model`:
        D400: First line should end with a period (not 'r')
torch/nn/parallel/distributed.py:2141 in private method `_set_params_and_buffers_to_ignore_for_model`:
        D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
torch/nn/parallel/distributed.py:2170 in private method `_get_ddp_logging_data`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2170 in private method `_get_ddp_logging_data`:
        D400: First line should end with a period (not 's')
torch/nn/parallel/distributed.py:2170 in private method `_get_ddp_logging_data`:
        D401: First line should be in imperative mood; try rephrasing (found 'This')
torch/nn/parallel/distributed.py:2184 in private method `_set_ddp_runtime_logging_sample_rate`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2184 in private method `_set_ddp_runtime_logging_sample_rate`:
        D400: First line should end with a period (not 'g')
torch/nn/parallel/distributed.py:2184 in private method `_set_ddp_runtime_logging_sample_rate`:
        D401: First line should be in imperative mood; try rephrasing (found 'This')
torch/nn/parallel/distributed.py:2202 in private method `_set_static_graph`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2202 in private method `_set_static_graph`:
        D400: First line should end with a period (not 'l')
torch/nn/parallel/distributed.py:2202 in private method `_set_static_graph`:
        D401: First line should be in imperative mood; try rephrasing (found 'It')
torch/nn/parallel/distributed.py:2227 in private method `_remove_autograd_hooks`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/parallel/distributed.py:2227 in private method `_remove_autograd_hooks`:
        D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
torch/nn/parallel/distributed.py:2233 in private method `_check_reducer_finalized`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/parallel/distributed.py:2233 in private method `_check_reducer_finalized`:
        D400: First line should end with a period (not 'd')
torch/nn/parallel/distributed.py:2233 in private method `_check_reducer_finalized`:
        D401: First line should be in imperative mood (perhaps 'Check', not 'Checks')
84
```

After: 12
```
torch/nn/parallel/distributed.py:1 at module level:
        D100: Missing docstring in public module
torch/nn/parallel/distributed.py:618 in public method `__init__`:
        D107: Missing docstring in __init__
torch/nn/parallel/distributed.py:1133 in public method `__getstate__`:
        D105: Missing docstring in magic method
torch/nn/parallel/distributed.py:1141 in public method `__setstate__`:
        D105: Missing docstring in magic method
torch/nn/parallel/distributed.py:1503 in public method `forward`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1513 in public method `scatter`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1516 in public method `to_kwargs`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1525 in public method `gather`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1528 in public method `train`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1734 in public method `join_device`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1738 in public method `join_process_group`:
        D102: Missing docstring in public method
torch/nn/parallel/distributed.py:1986 in public method `will_sync_module_buffers`:
        D102: Missing docstring in public method
12
```

- torch/nn/utils/_named_member_accessor.py
Before: 23
```
torch/nn/utils/_named_member_accessor.py:12 in public function `set_tensor`:
        D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:29 in public function `swap_tensor`:
        D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:85 in public function `swap_submodule`:
        D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:109 in public class `NamedMemberAccessor`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:109 in public class `NamedMemberAccessor`:
        D400: First line should end with a period (not 's')
torch/nn/utils/_named_member_accessor.py:115 in public method `__init__`:
        D107: Missing docstring in __init__
torch/nn/utils/_named_member_accessor.py:122 in public method `get_submodule`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:155 in public method `swap_submodule`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:164 in public method `get_tensor`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:185 in public method `set_tensor`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:194 in public method `del_tensor`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:211 in public method `swap_tensor`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:224 in public method `get_tensors`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:233 in public method `set_tensors`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:249 in public method `set_tensors_dict`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:261 in public method `del_tensors`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:276 in public method `swap_tensors`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:296 in public method `swap_tensors_dict`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_named_member_accessor.py:325 in public method `check_keys`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/utils/_named_member_accessor.py:340 in public method `named_parameters`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/utils/_named_member_accessor.py:349 in public method `named_buffers`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/utils/_named_member_accessor.py:358 in public method `named_tensors`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/nn/utils/_named_member_accessor.py:368 in public method `named_modules`:
        D200: One-line docstring should fit on one line with quotes (found 3)
23
```

After: 4
```
torch/nn/utils/_named_member_accessor.py:12 in public function `set_tensor`:
        D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:29 in public function `swap_tensor`:
        D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:85 in public function `swap_submodule`:
        D103: Missing docstring in public function
torch/nn/utils/_named_member_accessor.py:116 in public method `__init__`:
        D107: Missing docstring in __init__
4
```

- torch/nn/utils/_per_sample_grad.py
Before: 3
```
torch/nn/utils/_per_sample_grad.py:12 in public function `call_for_per_sample_grads`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/_per_sample_grad.py:12 in public function `call_for_per_sample_grads`:
        D400: First line should end with a period (not ')')
torch/nn/utils/_per_sample_grad.py:12 in public function `call_for_per_sample_grads`:
        D402: First line should not be the function's "signature"
3
```
After: 0
```
0
```

- torch/nn/utils/init.py
Before: 3
```
torch/nn/utils/init.py:1 at module level:
        D100: Missing docstring in public module
torch/nn/utils/init.py:6 in public function `skip_init`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/init.py:6 in public function `skip_init`:
        D400: First line should end with a period (not 'g')
3
```
After: 1
```
torch/nn/utils/init.py:1 at module level:
        D100: Missing docstring in public module
1
```

- torch/nn/utils/memory_format.py
Before: 4
```
torch/nn/utils/memory_format.py:1 at module level:
        D100: Missing docstring in public module
torch/nn/utils/memory_format.py:5 in public function `convert_conv2d_weight_memory_format`:
        D202: No blank lines allowed after function docstring (found 1)
torch/nn/utils/memory_format.py:5 in public function `convert_conv2d_weight_memory_format`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/utils/memory_format.py:5 in public function `convert_conv2d_weight_memory_format`:
        D400: First line should end with a period (not '`')
4
```
After: 1
```
torch/nn/utils/memory_format.py:1 at module level:
        D100: Missing docstring in public module
1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112657
Approved by: https://github.com/fduwjj
2023-11-02 05:52:47 +00:00

103 lines
5.5 KiB
Python

import functools
import torch
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight
from torch.utils import _pytree as pytree
# dependency on `functional_call` means that this can't be exposed in utils
# without creating circular dependency
def call_for_per_sample_grads(module, *, batch_size=None, loss_reduction="sum", batch_first=True):
r"""
Return a forward function for a module, populating grad_sample with per sample gradients on backward invocation.
Args:
module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
parameters will compute per sample gradients, located in a ``grad_sample``
field when ``backward`` is invoked
batch_size: The batch size of the input. If None is passed, all tensor arguments in args and kwargs must have
the same batch size, which is the size of the first dimension. Otherwise, it must be passed manually.
Default: None
loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If
"mean", per sample gradients will be scaled by the batch size to offset the crossbatch interaction from
running mean across a batch. Must be "mean" or "sum". Default: "sum"
batch_first: Indicates if the batch dimension is the first dimension. If True, the batch dimension is the first
dimension. If False, it's the second dimension. Default: True.
Examples::
>>> # xdoctest: +SKIP
>>> model = nn.Linear(4, 3)
>>> batched_input = torch.randn(5, 4) # batch size of 5
>>> res = call_for_per_sample_grads(model)(batched_input).sum()
>>> res.backward()
>>> assert model.weight.shape == (3, 4)
>>> assert model.weight.grad_sample.shape == (5, 3, 4)
>>> assert model.weight.grad is None
>>> assert model.bias.shape == (3,)
>>> assert model.bias.grad_sample.shape == (5, 3)
>>> assert model.bias.grad is None
An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be
if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all
grad_outputs by 1 / batch_size from cross batch interaction.
>>> model = nn.Linear(4, 3)
>>> batched_input = torch.randn(5, 4) # batch size of 5
>>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")(batched_input).mean()
>>> res.backward()
Note::
Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
rewrites that wrap an `nn.Linear` module. See Opacus for an example
"""
def maybe_build_expanded_weight(og_tensor, batch_size):
if og_tensor.requires_grad:
return ExpandedWeight(og_tensor, batch_size, loss_reduction)
else:
return og_tensor
def compute_batch_size(*args, **kwargs):
args_and_kwargs = pytree.arg_tree_leaves(*args, **kwargs)
batch_size = None
for arg in args_and_kwargs:
if not isinstance(arg, torch.Tensor):
continue
arg_batch_size = arg.shape[0] if batch_first else arg.shape[1]
if batch_size is not None and batch_size != arg_batch_size:
raise RuntimeError("When computing batch size, found at least one input with batch size "
f"{batch_size} and one with batch size {arg_batch_size}. Please specify it "
"explicitly using the batch size kwarg in call_for_per_sample_grads")
batch_size = arg_batch_size
if batch_size is None:
raise RuntimeError("Unable to find a tensor in the passed args and kwargs. They may not be pytree-able "
"and so ExpandedWeights cannot compute the batch size from the inputs. Please specify "
"it explicitly")
return batch_size
if loss_reduction not in ["sum", "mean"]:
raise RuntimeError(f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}")
if not isinstance(module, torch.nn.Module):
raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}")
if not (batch_size is None or isinstance(batch_size, int)):
raise RuntimeError(f"Batch size passed must be None or an integer, got {type(batch_size).__name__}")
if batch_size is not None and batch_size < 1:
raise RuntimeError(f"Batch size must be positive, got {batch_size}")
for weight in module.parameters():
if hasattr(weight, "grad_sample") and weight.grad_sample is not None: # type: ignore[attr-defined]
raise RuntimeError("Current Expanded Weights accumulates the gradients, which will be incorrect for multiple "
f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or "
"post an issue to pytorch/pytorch to prioritize correct behavior")
@functools.wraps(module.forward)
def wrapper(*args, **kwargs):
wrapper_batch_size = batch_size
if wrapper_batch_size is None:
wrapper_batch_size = compute_batch_size(*args, **kwargs)
params = {name: maybe_build_expanded_weight(value, wrapper_batch_size) for (name, value) in module.named_parameters()}
return torch.func.functional_call(module, params, args, kwargs)
return wrapper