[Expanded Weights] fix loss reduction (#80892)

Two changes in here:
(1) Changes `call_for_per_sample_grads` to be curried. Old call looks like:
`call_for_per_sample_grads(module, batch_size, args, kwargs)`
New call looks like:
`call_for_per_sample_grads(module, batch_size, loss_reduction=loss_reduction)(args, kwargs)`

(2) Adds the ability to specify a loss reduction, to match what is done in Opacus. Opacus has a more complete explanation but essentially, they want the per sample gradient behavior to match what is happens in a for loop with a single example. This gets messed up if you use a mean reduction at the end since in a batch that ends up scaling all the grad_outputs by 1/batch_size, so we offset that by scaling all the grad_samples by batch_size if the loss_reduction is mean
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80892
Approved by: https://github.com/zou3519
This commit is contained in:
samdow
2022-07-18 13:19:16 -04:00
committed by PyTorch MergeBot
parent 71e16f9eef
commit 799bc645d9
4 changed files with 155 additions and 72 deletions

View File

@ -27,8 +27,12 @@ class TestExpandedWeightHelperFunction(TestCase):
weight = torch.randn(5, 4, device=device)
bias = torch.randn(5, device=device)
for (weight_batched, bias_batched) in product([True, False], [True, False]):
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3) if weight_batched else weight
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3) if bias_batched else bias
maybe_batched_weight = weight
maybe_batched_bias = bias
if weight_batched:
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3, loss_reduction="sum")
if bias_batched:
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3, loss_reduction="sum")
args = (input, maybe_batched_weight, maybe_batched_bias)
expanded_args, expanded_kwargs = standard_kwargs(('bias',), args)
res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
@ -45,7 +49,7 @@ class TestExpandedWeightHelperFunction(TestCase):
weight = torch.randn(5, 4, device=device)
bias = torch.randn(5, device=device)
with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."):
input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3)
input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum")
expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias))
forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"):
@ -61,18 +65,22 @@ class TestExpandedWeightHelperFunction(TestCase):
for (weight_batched, bias_batched) in product([True, False], [True, False]):
if not weight_batched and not bias_batched:
continue
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4) if weight_batched else weight
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4) if bias_batched else bias
maybe_batched_weight = weight
maybe_batched_bias = bias
if weight_batched:
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4, loss_reduction="sum")
if bias_batched:
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4, loss_reduction="sum")
with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"):
expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias))
forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
def test_set_grad_sample_if_exists(self, device):
def test_fn(_):
def test_fn(a):
return True
orig_weight = torch.randn(4, device=device, requires_grad=True)
expanded_weight = ExpandedWeight(orig_weight, 3)
expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
set_grad_sample_if_exists(expanded_weight, test_fn)
self.assertTrue(hasattr(orig_weight, 'grad_sample'))
self.assertTrue(orig_weight.grad_sample)
@ -86,7 +94,7 @@ class TestExpandedWeightHelperFunction(TestCase):
self.assertFalse(hasattr(non_tensor, 'grad_sample'))
def test_set_grad_sample_if_exists_failure(self, device):
def test_fn(_):
def test_fn(a):
return True
grad_tensor = torch.randn(4, requires_grad=True, device=device)
@ -95,7 +103,7 @@ class TestExpandedWeightHelperFunction(TestCase):
def test_unpack_expanded_weight_or_tensor(self, device):
input = torch.randn(3, requires_grad=True, device=device)
self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3)))
self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum")))
input.requires_grad_(False)
self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
@ -103,7 +111,7 @@ class TestExpandedWeightHelperFunction(TestCase):
def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
input = torch.randn(3, requires_grad=True, device=device)
self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3), lambda x: x is input))
self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input))
input.requires_grad_(False)
self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
@ -131,38 +139,58 @@ class TestExpandedWeightHelperFunction(TestCase):
self.assertEqual(res, input)
class TestExpandedWeightFunctional(TestCase):
def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
input = sample_input.input
args = sample_input.args
kwargs = sample_input.kwargs
batch_size = input.shape[0] if len(input.shape) > 1 else 1
# get per sample grads with ExpandedWeights objects
loss_reduction = "sum" if reduction == torch.sum else "mean"
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction)
diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
if not diff_input_list:
return
result = run_op(op, ew_input, *ew_args, **ew_kwargs)
reduction(result).backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__
expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list)
# get per sample grads with for loop
func = partial(run_op, op)
per_sample_grad = for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs)
# check equality
self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
if loss_reduction == "mean":
# don't check equality of `input.grad`s since these vanilla tensors won't be scaled
expanded_weight_grad = expanded_weight_grad[1:]
per_sample_grad = per_sample_grad[1:]
for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
self.assertEqual(result_grad, expected_grad)
@ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
def test_expanded_weight_per_sample_grad(self, device, dtype, op):
def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
for sample_input in supported_inputs(op, sample_inputs):
if op.name == "nn.functional.embedding": # embedding flips its argument order for autograd tests
sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
input = sample_input.input
args = sample_input.args
kwargs = sample_input.kwargs
batch_size = input.shape[0] if len(input.shape) > 1 else 1
# get per sample grads with ExpandedWeights objects
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
if not diff_input_list:
continue
result = run_op(op, ew_input, *ew_args, **ew_kwargs)
result.sum().backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__
expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list)
def reduction(x):
return x.sum()
# get per sample grads with for loop
func = partial(run_op, op)
per_sample_grad = for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs)
self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)
# check equality
self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
if result_grad is None:
result_grad = torch.zeros_like(expected_grad)
self.assertEqual(result_grad, expected_grad)
@ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
for sample_input in supported_inputs(op, sample_inputs):
if op.name == "nn.functional.embedding": # embedding flips its argument order for autograd tests
sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)
@ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
def test_unsupported_expand_weights(self, device, dtype, op):
@ -195,34 +223,35 @@ class TestExpandedWeightFunctional(TestCase):
if "cuda" in device and "max_norm" in sample_input.kwargs and "padding_idx" in sample_input.kwargs:
self.skipTest("embedding is non-determinstic in this case, see issue #74679")
batch_size = sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
self.assertEqual(expanded_weight_result, normal_result)
for loss_reduction in ["sum", "mean"]:
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction)
expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
self.assertEqual(expanded_weight_result, normal_result)
def test_expanded_weight_error(self, device):
batch_size = 3
sample_input = make_tensor((batch_size, 4), dtype=torch.float32, device=device, requires_grad=True)
sample_weight = make_tensor((4), dtype=torch.float32, device=device, requires_grad=True)
with self.assertRaisesRegex(RuntimeError, r"Expanded Weights encountered but cannot handle function"):
torch.add(sample_input, ExpandedWeight(sample_weight, batch_size))
torch.add(sample_input, ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"))
def _test_embedding_model(self, model, num_embedding, device):
batch_size = 32
input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
return self._test_model(partial(model, num_embedding=num_embedding), batch_size, input, device)
def _test_conv_model(self, model, input_size, num_dim, device):
def _test_conv_model(self, model, input_size, num_dim, device, loss_reduction="sum"):
batch_size = 32
input_ending = [input_size] * num_dim
input = torch.randn([batch_size, 3] + input_ending, device=device)
return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device)
return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device, loss_reduction)
def _test_model(self, model, batch_size, input, device):
def _test_model(self, model, batch_size, input, device, loss_reduction="sum"):
model = model(10).to(device)
targets = torch.randint(0, 10, (batch_size,), device=device)
criterion = CrossEntropyLoss(reduction='sum') # use a loss that doesn't average across the batch to test in a for loop
result = call_for_per_sample_grads(model, batch_size, input)
criterion = CrossEntropyLoss(reduction=loss_reduction)
result = call_for_per_sample_grads(model, batch_size, loss_reduction=loss_reduction)(input)
loss = criterion(result, targets)
loss.backward()
result = []
@ -240,7 +269,7 @@ class TestExpandedWeightFunctional(TestCase):
self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)
def test_cnn_model(self, device):
def test_cnn_model_sum(self, device):
def convnet(num_classes, num_dim):
return nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
@ -261,6 +290,27 @@ class TestExpandedWeightFunctional(TestCase):
return self._test_conv_model(convnet, 28, 2, device)
def test_cnn_model_mean(self, device):
def convnet(num_classes, num_dim):
return nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(start_dim=1, end_dim=-1),
nn.Linear(128, num_classes, bias=True),
)
return self._test_conv_model(convnet, 28, 2, device, loss_reduction="mean")
@parametrize('num_dim', [1, 2, 3])
def test_instance_norm_model(self, num_dim, device):
def instance_norm_model(num_classes, num_dim):
@ -326,7 +376,7 @@ class TestExpandedWeightModule(TestCase):
input.requires_grad_()
with freeze_rng_state():
# get per sample grads with ExpandedWeights context manager
actual_res = call_for_per_sample_grads(module, batch_size, input).sum()
actual_res = call_for_per_sample_grads(module, batch_size, loss_reduction="sum")(input).sum()
actual_res.backward()
actual_grads = []
for param in module.parameters():
@ -368,7 +418,7 @@ class TestExpandedWeightModule(TestCase):
with freeze_rng_state():
# get per sample grads with ExpandedWeights context manager, calling .backward() twice
test_module = TestModule(module)
actual_res = call_for_per_sample_grads(test_module, batch_size, input).sum()
actual_res = call_for_per_sample_grads(test_module, batch_size, loss_reduction="sum")(input).sum()
actual_res.backward()
actual_grads = []
for param in module.parameters():
@ -397,15 +447,17 @@ class TestExpandedWeightModule(TestCase):
module = nn.Linear(10, 10)
input = torch.randn(64, 10)
with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
call_for_per_sample_grads("fail", 64, input)
call_for_per_sample_grads("fail", 64)(input)
with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be an integer"):
call_for_per_sample_grads(module, 6.4, input)
call_for_per_sample_grads(module, 6.4)(input)
with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
call_for_per_sample_grads(module, -64, input)
call_for_per_sample_grads(module, -64)(input)
with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
loss = call_for_per_sample_grads(module, 64, input).sum()
loss = call_for_per_sample_grads(module, 64)(input).sum()
loss.backward() # populate grad_sample fields
call_for_per_sample_grads(module, 64, input)
call_for_per_sample_grads(module, 64)(input)
with self.assertRaisesRegex(RuntimeError, r"Expected loss_reduction argument to be sum or mean"):
call_for_per_sample_grads(module, -64, loss_reduction="")(input)
class ContextManagerTests(TestBase):
def __init__(self, *args, **kwargs):
@ -478,9 +530,11 @@ def run_op(op, input, *args, **kwargs):
else:
return op(input, *args, **kwargs)
def make_expanded_weight(sample_input, batch_size):
def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
def expanded_weight_or_clone(arg):
return ExpandedWeight(torch.clone(arg), batch_size) if is_diff_tensor(arg) else clone_if_tensor(arg)
if is_diff_tensor(arg):
return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
return clone_if_tensor(arg)
ew_input = clone_if_tensor(sample_input.input)
ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
@ -513,12 +567,12 @@ def supported_inputs(op, sample_inputs, supported_inputs=True):
return is_supported_input if supported_inputs else not is_supported_input
return [input for input in sample_inputs if filter_fn(input)]
def for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs):
def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
# get per sample grads by getting derivative for each input in a for loop
per_sample_grad = []
for i in range(batch_size):
per_sample_input = input[i]
result = func(per_sample_input.unsqueeze(0), *args, **kwargs)
result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
diff_input_list = [i for i in diff_input_list if isinstance(i, torch.Tensor) and i.requires_grad]
per_sample_grad.append(torch.autograd.grad(result, diff_input_list, torch.ones_like(result), allow_unused=True))

View File

@ -26,13 +26,14 @@ def implements_per_sample_grads(torch_function):
#
# Needs to be a tensor subclass to allow reparamaterization
class ExpandedWeight(torch.Tensor):
def __init__(self, orig_weight, batch_size):
def __init__(self, orig_weight, batch_size, loss_reduction):
self.batch_size = batch_size
self.orig_weight = orig_weight
self.loss_reduction = loss_reduction
handled_functions = HANDLED_FUNCTIONS
def __new__(cls, orig_weight, _):
def __new__(cls, orig_weight, batch_size, loss_reduction):
if not isinstance(orig_weight, torch.Tensor):
raise RuntimeError(f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}")
if not orig_weight.requires_grad:

View File

@ -1,3 +1,5 @@
from typing import Optional
import torch
from .expanded_weights_impl import ExpandedWeight
@ -52,18 +54,34 @@ def _check_and_unexpand_args(func, expanded_args, expanded_kwargs):
raise RuntimeError("Expected ExpandedWeights to have batch size matching input but got "
f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}")
loss_reduction: Optional[str] = None
for arg in expanded_args + tuple(expanded_kwargs.values()):
if isinstance(arg, ExpandedWeight):
if loss_reduction is None:
loss_reduction = arg.loss_reduction
elif loss_reduction != arg.loss_reduction:
raise RuntimeError("Expected ExpandedWeights to all have the same loss_reduction argument but got one"
f"with {loss_reduction} and one with {arg.loss_reduction}")
unexpanded_args = tuple(arg.orig_weight if isinstance(arg, ExpandedWeight) else arg for arg in expanded_args)
unexpanded_kwargs = {name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg
for (name, arg) in expanded_kwargs.items()}
return unexpanded_args, unexpanded_kwargs
def maybe_scale_by_batch_size(grad_sample, expanded_weight):
if expanded_weight.loss_reduction == "mean":
return grad_sample * expanded_weight.batch_size
else:
return grad_sample
def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn):
unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight)
if isinstance(maybe_expanded_weight, ExpandedWeight):
grad_sample_contribution = maybe_scale_by_batch_size(per_sample_grad_fn(unpacked), maybe_expanded_weight)
if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None:
unpacked.grad_sample = unpacked.grad_sample + per_sample_grad_fn(unpacked)
unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution
else:
unpacked.grad_sample = per_sample_grad_fn(unpacked)
unpacked.grad_sample = grad_sample_contribution
def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x):
if isinstance(maybe_expanded_weight, ExpandedWeight):

View File

@ -1,29 +1,31 @@
import functools
import torch
from torch.nn.utils._stateless import functional_call
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight
# dependency on `functional_call` means that this can't be exposed in utils
# without creating circular dependency
def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
def call_for_per_sample_grads(module, batch_size, loss_reduction="sum"):
r"""
call_for_per_sample_grads(module, batch_size, args, kwargs=None) -> Tensor
Invoked just like a forward pass, ``call_for_per_sample_grads`` will produce the same
forward result. Then, when backward is invoked, the parameters of ``module``
will have a ``grad_sample`` field populated with the per sample gradients
instead of the regular gradients
call_for_per_sample_grads(module, batch_size, loss_reduction="sum")
``call_for_per_sample_grads`` returns a function that is invoked like the forward
function of ``module`` and will produce the same result. Then, when backward is invoked,
the parameters of ``module`` will have a ``grad_sample`` field populated with the per sample
gradients instead of the regular gradients
Args:
module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
parameters will compute per sample gradients, located in a ``grad_sample``
field when ``backward`` is invoked
batch_size: The batch size of the input. Typically the input's first dimension
args: Tuple of positional args passed to ``module`` to perform the forward pass
kwargs: Dict of named args passed to ``module`` to perform the forward pass. Default: None
loss_reduction: The reduction used on the loss. If "mean", per sample gradients will be scaled by the batch size
to offset the crossbatch interaction from running mean across a batch. Must be "mean" or "sum". Default: "sum"
Examples::
>>> model = nn.Linear(4, 3)
>>> batched_input = torch.randn(5, 4) # batch size of 5
>>> res = call_for_per_sample_grads(model, batched_input.shape[0], batched_input).sum()
>>> res = call_for_per_sample_grads(model, batched_input.shape[0])(batched_input).sum()
>>> res.backward()
>>> assert model.weight.shape == (3, 4)
>>> assert model.weight.grad_sample.shape == (5, 3, 4)
@ -36,12 +38,16 @@ def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
rewrites that wrap an `nn.Linear` module. See Opacus for an example
"""
def maybe_build_expanded_weight(og_tensor):
if og_tensor.requires_grad:
return ExpandedWeight(og_tensor, batch_size)
return ExpandedWeight(og_tensor, batch_size, loss_reduction)
else:
return og_tensor
if loss_reduction not in ["sum", "mean"]:
raise RuntimeError(f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}")
if not isinstance(module, torch.nn.Module):
raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}")
if not isinstance(batch_size, int):
@ -53,5 +59,9 @@ def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
raise RuntimeError("Current Expanded Weights accumulates the gradients, which will be incorrect for multiple "
f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or "
"post an issue to pytorch/pytorch to prioritize correct behavior")
params = {name: maybe_build_expanded_weight(value) for (name, value) in module.named_parameters()}
return functional_call(module, params, args, kwargs)
@functools.wraps(module.forward)
def wrapper(*args, **kwargs):
params = {name: maybe_build_expanded_weight(value) for (name, value) in module.named_parameters()}
return functional_call(module, params, args, kwargs)
return wrapper