Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70140

[Design Doc for Expanded Weights](https://gist.github.com/samdow/fa0a164fec7963f93ff45284989cfc55) <-- gives an overview of the design for Expanded Weights

Introduces the ExpandedWeights mechanism and user-facing API without any custom-implemented faster rules.
- User facing API is in `_stateless.py` (with documentation)
- Testing is in test_expanded_weights
- The rest is the implementation of the erroring fallback + the mechanism for being able to register faster per sample grad rules. Only linear is implemented here, but they are all implemented in #70141

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D34350950

Pulled By: samdow

fbshipit-source-id: 69c664b0bc3dff6951358d79d7e5d94882f7aef2
(cherry picked from commit ae1620d3b6507b27c3bc08ecfb2b1418aa8ce7d7)
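For orientation, the mechanism described above boils down to wrapping each trainable parameter in an `ExpandedWeight` (which records the batch size) and running the forward through `functional_call`, so that backward dispatches to a registered per sample grad rule or to the erroring fallback. A rough sketch of that flow, using only names that appear in the file below and an `nn.Linear` model (the one op with a fast rule in this PR, per the summary above):

import torch
from torch.nn.utils._stateless import functional_call
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight

model = torch.nn.Linear(4, 3)
batch_size = 5
batched_input = torch.randn(batch_size, 4)

# Wrap trainable parameters so autograd goes through the ExpandedWeights rules.
params = {name: ExpandedWeight(p, batch_size) if p.requires_grad else p
          for name, p in model.named_parameters()}
functional_call(model, params, batched_input).sum().backward()

# Per sample gradients land in .grad_sample rather than .grad.
assert model.weight.grad_sample.shape == (batch_size, 3, 4)
assert model.weight.grad is None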
58 lines
3.1 KiB
Python
import torch

from torch.nn.utils._stateless import functional_call
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight


# dependency on `functional_call` means that this can't be exposed in utils
# without creating a circular dependency
def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
    r"""
    call_for_per_sample_grads(module, batch_size, args, kwargs=None) -> Tensor

    Invoked just like a forward pass, ``call_for_per_sample_grads`` will produce the same
    forward result. Then, when backward is invoked, the parameters of ``module``
    will have a ``grad_sample`` field populated with the per sample gradients
    instead of the regular gradients.

    Args:
        module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
          parameters will compute per sample gradients, located in a ``grad_sample``
          field when ``backward`` is invoked.
        batch_size: The batch size of the input. Typically the input's first dimension.
        args: Tuple of positional args passed to ``module`` to perform the forward pass.
        kwargs: Dict of named args passed to ``module`` to perform the forward pass. Default: None

    Examples::
        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model, batched_input.shape[0], batched_input).sum()
        >>> res.backward()
        >>> assert model.weight.shape == (3, 4)
        >>> assert model.weight.grad_sample.shape == (5, 3, 4)
        >>> assert model.weight.grad is None
        >>> assert model.bias.shape == (3,)
        >>> assert model.bias.grad_sample.shape == (5, 3)
        >>> assert model.bias.grad is None

    Note::
        Does not work with RNN modules (``nn.RNN``, ``nn.GRU``, or ``nn.LSTM``). Please use
        custom rewrites that wrap an ``nn.Linear`` module instead. See Opacus for an example.
    """
    def maybe_build_expanded_weight(og_tensor):
        # Only trainable parameters are wrapped in ExpandedWeight; the wrapper carries the
        # batch size so backward can populate per sample gradients in `grad_sample`.
        if og_tensor.requires_grad:
            return ExpandedWeight(og_tensor, batch_size)
        else:
            return og_tensor

    if not isinstance(module, torch.nn.Module):
        raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}")
    if not isinstance(batch_size, int):
        raise RuntimeError(f"Batch size passed must be an integer, got {type(batch_size).__name__}")
    if batch_size < 1:
        raise RuntimeError(f"Batch size must be positive, got {batch_size}")
    # Guard against stale per sample gradients from an earlier call: Expanded Weights would
    # accumulate into them, producing incorrect results.
    for weight in module.parameters():
        if hasattr(weight, "grad_sample") and weight.grad_sample is not None:  # type: ignore[attr-defined]
            raise RuntimeError("Expanded Weights currently accumulates the gradients, which will be incorrect for "
                               f"multiple calls without clearing gradients. Please clear out the grad_sample field of {weight} "
                               "or post an issue to pytorch/pytorch to prioritize correct behavior")
    # Swap every trainable parameter for its ExpandedWeight wrapper and run the forward
    # functionally, so the module's real parameters are left untouched.
    params = {name: maybe_build_expanded_weight(value) for (name, value) in module.named_parameters()}
    return functional_call(module, params, args, kwargs)
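One way to sanity-check the per sample gradients this produces is to compare them against gradients computed one example at a time. A minimal sketch, assuming it runs in the same module as the code above (so `call_for_per_sample_grads` is in scope) and that the loss is a plain sum over examples, so per sample gradients coincide with per-example gradients:

import torch
import torch.nn as nn

model = nn.Linear(4, 3)
batched_input = torch.randn(5, 4)  # batch size of 5

# Per sample gradients via the API above: populates .grad_sample, leaves .grad as None.
call_for_per_sample_grads(model, batched_input.shape[0], batched_input).sum().backward()

# Reference: run each example on its own and collect its gradient.
expected = []
for i in range(batched_input.shape[0]):
    model.zero_grad()
    model(batched_input[i:i + 1]).sum().backward()
    expected.append(model.weight.grad.clone())

assert torch.allclose(model.weight.grad_sample, torch.stack(expected))

# Before calling again, clear the stale per sample gradients; otherwise the
# check at the top of call_for_per_sample_grads will (intentionally) error.
for p in model.parameters():
    p.grad_sample = None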