mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Expanded Weights] fix loss reduction (#80892)
Two changes in here: (1) Changes `call_for_per_sample_grads` to be curried. Old call looks like: `call_for_per_sample_grads(module, batch_size, args, kwargs)` New call looks like: `call_for_per_sample_grads(module, batch_size, loss_reduction=loss_reduction)(args, kwargs)` (2) Adds the ability to specify a loss reduction, to match what is done in Opacus. Opacus has a more complete explanation but essentially, they want the per sample gradient behavior to match what is happens in a for loop with a single example. This gets messed up if you use a mean reduction at the end since in a batch that ends up scaling all the grad_outputs by 1/batch_size, so we offset that by scaling all the grad_samples by batch_size if the loss_reduction is mean Pull Request resolved: https://github.com/pytorch/pytorch/pull/80892 Approved by: https://github.com/zou3519
This commit is contained in:
@ -27,8 +27,12 @@ class TestExpandedWeightHelperFunction(TestCase):
|
||||
weight = torch.randn(5, 4, device=device)
|
||||
bias = torch.randn(5, device=device)
|
||||
for (weight_batched, bias_batched) in product([True, False], [True, False]):
|
||||
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3) if weight_batched else weight
|
||||
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3) if bias_batched else bias
|
||||
maybe_batched_weight = weight
|
||||
maybe_batched_bias = bias
|
||||
if weight_batched:
|
||||
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3, loss_reduction="sum")
|
||||
if bias_batched:
|
||||
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3, loss_reduction="sum")
|
||||
args = (input, maybe_batched_weight, maybe_batched_bias)
|
||||
expanded_args, expanded_kwargs = standard_kwargs(('bias',), args)
|
||||
res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
|
||||
@ -45,7 +49,7 @@ class TestExpandedWeightHelperFunction(TestCase):
|
||||
weight = torch.randn(5, 4, device=device)
|
||||
bias = torch.randn(5, device=device)
|
||||
with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."):
|
||||
input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3)
|
||||
input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum")
|
||||
expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias))
|
||||
forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
|
||||
with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"):
|
||||
@ -61,18 +65,22 @@ class TestExpandedWeightHelperFunction(TestCase):
|
||||
for (weight_batched, bias_batched) in product([True, False], [True, False]):
|
||||
if not weight_batched and not bias_batched:
|
||||
continue
|
||||
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4) if weight_batched else weight
|
||||
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4) if bias_batched else bias
|
||||
maybe_batched_weight = weight
|
||||
maybe_batched_bias = bias
|
||||
if weight_batched:
|
||||
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4, loss_reduction="sum")
|
||||
if bias_batched:
|
||||
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4, loss_reduction="sum")
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"):
|
||||
expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias))
|
||||
forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
|
||||
|
||||
def test_set_grad_sample_if_exists(self, device):
|
||||
def test_fn(_):
|
||||
def test_fn(a):
|
||||
return True
|
||||
|
||||
orig_weight = torch.randn(4, device=device, requires_grad=True)
|
||||
expanded_weight = ExpandedWeight(orig_weight, 3)
|
||||
expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
|
||||
set_grad_sample_if_exists(expanded_weight, test_fn)
|
||||
self.assertTrue(hasattr(orig_weight, 'grad_sample'))
|
||||
self.assertTrue(orig_weight.grad_sample)
|
||||
@ -86,7 +94,7 @@ class TestExpandedWeightHelperFunction(TestCase):
|
||||
self.assertFalse(hasattr(non_tensor, 'grad_sample'))
|
||||
|
||||
def test_set_grad_sample_if_exists_failure(self, device):
|
||||
def test_fn(_):
|
||||
def test_fn(a):
|
||||
return True
|
||||
|
||||
grad_tensor = torch.randn(4, requires_grad=True, device=device)
|
||||
@ -95,7 +103,7 @@ class TestExpandedWeightHelperFunction(TestCase):
|
||||
|
||||
def test_unpack_expanded_weight_or_tensor(self, device):
|
||||
input = torch.randn(3, requires_grad=True, device=device)
|
||||
self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3)))
|
||||
self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum")))
|
||||
|
||||
input.requires_grad_(False)
|
||||
self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
|
||||
@ -103,7 +111,7 @@ class TestExpandedWeightHelperFunction(TestCase):
|
||||
|
||||
def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
|
||||
input = torch.randn(3, requires_grad=True, device=device)
|
||||
self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3), lambda x: x is input))
|
||||
self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input))
|
||||
|
||||
input.requires_grad_(False)
|
||||
self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
|
||||
@ -131,38 +139,58 @@ class TestExpandedWeightHelperFunction(TestCase):
|
||||
self.assertEqual(res, input)
|
||||
|
||||
class TestExpandedWeightFunctional(TestCase):
|
||||
def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
|
||||
input = sample_input.input
|
||||
args = sample_input.args
|
||||
kwargs = sample_input.kwargs
|
||||
batch_size = input.shape[0] if len(input.shape) > 1 else 1
|
||||
|
||||
# get per sample grads with ExpandedWeights objects
|
||||
loss_reduction = "sum" if reduction == torch.sum else "mean"
|
||||
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction)
|
||||
diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
|
||||
diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
|
||||
diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
|
||||
if not diff_input_list:
|
||||
return
|
||||
result = run_op(op, ew_input, *ew_args, **ew_kwargs)
|
||||
reduction(result).backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__
|
||||
expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list)
|
||||
|
||||
# get per sample grads with for loop
|
||||
func = partial(run_op, op)
|
||||
|
||||
per_sample_grad = for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs)
|
||||
|
||||
# check equality
|
||||
self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
|
||||
if loss_reduction == "mean":
|
||||
# don't check equality of `input.grad`s since these vanilla tensors won't be scaled
|
||||
expanded_weight_grad = expanded_weight_grad[1:]
|
||||
per_sample_grad = per_sample_grad[1:]
|
||||
for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
|
||||
self.assertEqual(result_grad, expected_grad)
|
||||
|
||||
@ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
|
||||
def test_expanded_weight_per_sample_grad(self, device, dtype, op):
|
||||
def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
|
||||
sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
|
||||
for sample_input in supported_inputs(op, sample_inputs):
|
||||
if op.name == "nn.functional.embedding": # embedding flips its argument order for autograd tests
|
||||
sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
|
||||
input = sample_input.input
|
||||
args = sample_input.args
|
||||
kwargs = sample_input.kwargs
|
||||
batch_size = input.shape[0] if len(input.shape) > 1 else 1
|
||||
|
||||
# get per sample grads with ExpandedWeights objects
|
||||
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
|
||||
diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
|
||||
diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
|
||||
diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
|
||||
if not diff_input_list:
|
||||
continue
|
||||
result = run_op(op, ew_input, *ew_args, **ew_kwargs)
|
||||
result.sum().backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__
|
||||
expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list)
|
||||
def reduction(x):
|
||||
return x.sum()
|
||||
|
||||
# get per sample grads with for loop
|
||||
func = partial(run_op, op)
|
||||
per_sample_grad = for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs)
|
||||
self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)
|
||||
|
||||
# check equality
|
||||
self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
|
||||
for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
|
||||
if result_grad is None:
|
||||
result_grad = torch.zeros_like(expected_grad)
|
||||
self.assertEqual(result_grad, expected_grad)
|
||||
@ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
|
||||
def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
|
||||
sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
|
||||
for sample_input in supported_inputs(op, sample_inputs):
|
||||
if op.name == "nn.functional.embedding": # embedding flips its argument order for autograd tests
|
||||
sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
|
||||
|
||||
self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)
|
||||
|
||||
@ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
|
||||
def test_unsupported_expand_weights(self, device, dtype, op):
|
||||
@ -195,34 +223,35 @@ class TestExpandedWeightFunctional(TestCase):
|
||||
if "cuda" in device and "max_norm" in sample_input.kwargs and "padding_idx" in sample_input.kwargs:
|
||||
self.skipTest("embedding is non-determinstic in this case, see issue #74679")
|
||||
batch_size = sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
|
||||
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
|
||||
expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
|
||||
normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
|
||||
self.assertEqual(expanded_weight_result, normal_result)
|
||||
for loss_reduction in ["sum", "mean"]:
|
||||
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction)
|
||||
expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
|
||||
normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
|
||||
self.assertEqual(expanded_weight_result, normal_result)
|
||||
|
||||
def test_expanded_weight_error(self, device):
|
||||
batch_size = 3
|
||||
sample_input = make_tensor((batch_size, 4), dtype=torch.float32, device=device, requires_grad=True)
|
||||
sample_weight = make_tensor((4), dtype=torch.float32, device=device, requires_grad=True)
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expanded Weights encountered but cannot handle function"):
|
||||
torch.add(sample_input, ExpandedWeight(sample_weight, batch_size))
|
||||
torch.add(sample_input, ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"))
|
||||
|
||||
def _test_embedding_model(self, model, num_embedding, device):
|
||||
batch_size = 32
|
||||
input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
|
||||
return self._test_model(partial(model, num_embedding=num_embedding), batch_size, input, device)
|
||||
|
||||
def _test_conv_model(self, model, input_size, num_dim, device):
|
||||
def _test_conv_model(self, model, input_size, num_dim, device, loss_reduction="sum"):
|
||||
batch_size = 32
|
||||
input_ending = [input_size] * num_dim
|
||||
input = torch.randn([batch_size, 3] + input_ending, device=device)
|
||||
return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device)
|
||||
return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device, loss_reduction)
|
||||
|
||||
def _test_model(self, model, batch_size, input, device):
|
||||
def _test_model(self, model, batch_size, input, device, loss_reduction="sum"):
|
||||
model = model(10).to(device)
|
||||
targets = torch.randint(0, 10, (batch_size,), device=device)
|
||||
criterion = CrossEntropyLoss(reduction='sum') # use a loss that doesn't average across the batch to test in a for loop
|
||||
result = call_for_per_sample_grads(model, batch_size, input)
|
||||
criterion = CrossEntropyLoss(reduction=loss_reduction)
|
||||
result = call_for_per_sample_grads(model, batch_size, loss_reduction=loss_reduction)(input)
|
||||
loss = criterion(result, targets)
|
||||
loss.backward()
|
||||
result = []
|
||||
@ -240,7 +269,7 @@ class TestExpandedWeightFunctional(TestCase):
|
||||
self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)
|
||||
|
||||
|
||||
def test_cnn_model(self, device):
|
||||
def test_cnn_model_sum(self, device):
|
||||
def convnet(num_classes, num_dim):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
|
||||
@ -261,6 +290,27 @@ class TestExpandedWeightFunctional(TestCase):
|
||||
|
||||
return self._test_conv_model(convnet, 28, 2, device)
|
||||
|
||||
def test_cnn_model_mean(self, device):
|
||||
def convnet(num_classes, num_dim):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.AvgPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.AvgPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.AvgPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
nn.Flatten(start_dim=1, end_dim=-1),
|
||||
nn.Linear(128, num_classes, bias=True),
|
||||
)
|
||||
|
||||
return self._test_conv_model(convnet, 28, 2, device, loss_reduction="mean")
|
||||
|
||||
@parametrize('num_dim', [1, 2, 3])
|
||||
def test_instance_norm_model(self, num_dim, device):
|
||||
def instance_norm_model(num_classes, num_dim):
|
||||
@ -326,7 +376,7 @@ class TestExpandedWeightModule(TestCase):
|
||||
input.requires_grad_()
|
||||
with freeze_rng_state():
|
||||
# get per sample grads with ExpandedWeights context manager
|
||||
actual_res = call_for_per_sample_grads(module, batch_size, input).sum()
|
||||
actual_res = call_for_per_sample_grads(module, batch_size, loss_reduction="sum")(input).sum()
|
||||
actual_res.backward()
|
||||
actual_grads = []
|
||||
for param in module.parameters():
|
||||
@ -368,7 +418,7 @@ class TestExpandedWeightModule(TestCase):
|
||||
with freeze_rng_state():
|
||||
# get per sample grads with ExpandedWeights context manager, calling .backward() twice
|
||||
test_module = TestModule(module)
|
||||
actual_res = call_for_per_sample_grads(test_module, batch_size, input).sum()
|
||||
actual_res = call_for_per_sample_grads(test_module, batch_size, loss_reduction="sum")(input).sum()
|
||||
actual_res.backward()
|
||||
actual_grads = []
|
||||
for param in module.parameters():
|
||||
@ -397,15 +447,17 @@ class TestExpandedWeightModule(TestCase):
|
||||
module = nn.Linear(10, 10)
|
||||
input = torch.randn(64, 10)
|
||||
with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
|
||||
call_for_per_sample_grads("fail", 64, input)
|
||||
call_for_per_sample_grads("fail", 64)(input)
|
||||
with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be an integer"):
|
||||
call_for_per_sample_grads(module, 6.4, input)
|
||||
call_for_per_sample_grads(module, 6.4)(input)
|
||||
with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
|
||||
call_for_per_sample_grads(module, -64, input)
|
||||
call_for_per_sample_grads(module, -64)(input)
|
||||
with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
|
||||
loss = call_for_per_sample_grads(module, 64, input).sum()
|
||||
loss = call_for_per_sample_grads(module, 64)(input).sum()
|
||||
loss.backward() # populate grad_sample fields
|
||||
call_for_per_sample_grads(module, 64, input)
|
||||
call_for_per_sample_grads(module, 64)(input)
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected loss_reduction argument to be sum or mean"):
|
||||
call_for_per_sample_grads(module, -64, loss_reduction="")(input)
|
||||
|
||||
class ContextManagerTests(TestBase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -478,9 +530,11 @@ def run_op(op, input, *args, **kwargs):
|
||||
else:
|
||||
return op(input, *args, **kwargs)
|
||||
|
||||
def make_expanded_weight(sample_input, batch_size):
|
||||
def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
|
||||
def expanded_weight_or_clone(arg):
|
||||
return ExpandedWeight(torch.clone(arg), batch_size) if is_diff_tensor(arg) else clone_if_tensor(arg)
|
||||
if is_diff_tensor(arg):
|
||||
return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
|
||||
return clone_if_tensor(arg)
|
||||
|
||||
ew_input = clone_if_tensor(sample_input.input)
|
||||
ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
|
||||
@ -513,12 +567,12 @@ def supported_inputs(op, sample_inputs, supported_inputs=True):
|
||||
return is_supported_input if supported_inputs else not is_supported_input
|
||||
return [input for input in sample_inputs if filter_fn(input)]
|
||||
|
||||
def for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs):
|
||||
def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
|
||||
# get per sample grads by getting derivative for each input in a for loop
|
||||
per_sample_grad = []
|
||||
for i in range(batch_size):
|
||||
per_sample_input = input[i]
|
||||
result = func(per_sample_input.unsqueeze(0), *args, **kwargs)
|
||||
result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
|
||||
diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
|
||||
diff_input_list = [i for i in diff_input_list if isinstance(i, torch.Tensor) and i.requires_grad]
|
||||
per_sample_grad.append(torch.autograd.grad(result, diff_input_list, torch.ones_like(result), allow_unused=True))
|
||||
|
@ -26,13 +26,14 @@ def implements_per_sample_grads(torch_function):
|
||||
#
|
||||
# Needs to be a tensor subclass to allow reparamaterization
|
||||
class ExpandedWeight(torch.Tensor):
|
||||
def __init__(self, orig_weight, batch_size):
|
||||
def __init__(self, orig_weight, batch_size, loss_reduction):
|
||||
self.batch_size = batch_size
|
||||
self.orig_weight = orig_weight
|
||||
self.loss_reduction = loss_reduction
|
||||
|
||||
handled_functions = HANDLED_FUNCTIONS
|
||||
|
||||
def __new__(cls, orig_weight, _):
|
||||
def __new__(cls, orig_weight, batch_size, loss_reduction):
|
||||
if not isinstance(orig_weight, torch.Tensor):
|
||||
raise RuntimeError(f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}")
|
||||
if not orig_weight.requires_grad:
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from .expanded_weights_impl import ExpandedWeight
|
||||
|
||||
@ -52,18 +54,34 @@ def _check_and_unexpand_args(func, expanded_args, expanded_kwargs):
|
||||
raise RuntimeError("Expected ExpandedWeights to have batch size matching input but got "
|
||||
f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}")
|
||||
|
||||
loss_reduction: Optional[str] = None
|
||||
for arg in expanded_args + tuple(expanded_kwargs.values()):
|
||||
if isinstance(arg, ExpandedWeight):
|
||||
if loss_reduction is None:
|
||||
loss_reduction = arg.loss_reduction
|
||||
elif loss_reduction != arg.loss_reduction:
|
||||
raise RuntimeError("Expected ExpandedWeights to all have the same loss_reduction argument but got one"
|
||||
f"with {loss_reduction} and one with {arg.loss_reduction}")
|
||||
|
||||
unexpanded_args = tuple(arg.orig_weight if isinstance(arg, ExpandedWeight) else arg for arg in expanded_args)
|
||||
unexpanded_kwargs = {name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg
|
||||
for (name, arg) in expanded_kwargs.items()}
|
||||
return unexpanded_args, unexpanded_kwargs
|
||||
|
||||
def maybe_scale_by_batch_size(grad_sample, expanded_weight):
|
||||
if expanded_weight.loss_reduction == "mean":
|
||||
return grad_sample * expanded_weight.batch_size
|
||||
else:
|
||||
return grad_sample
|
||||
|
||||
def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn):
|
||||
unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight)
|
||||
if isinstance(maybe_expanded_weight, ExpandedWeight):
|
||||
grad_sample_contribution = maybe_scale_by_batch_size(per_sample_grad_fn(unpacked), maybe_expanded_weight)
|
||||
if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None:
|
||||
unpacked.grad_sample = unpacked.grad_sample + per_sample_grad_fn(unpacked)
|
||||
unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution
|
||||
else:
|
||||
unpacked.grad_sample = per_sample_grad_fn(unpacked)
|
||||
unpacked.grad_sample = grad_sample_contribution
|
||||
|
||||
def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x):
|
||||
if isinstance(maybe_expanded_weight, ExpandedWeight):
|
||||
|
@ -1,29 +1,31 @@
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch.nn.utils._stateless import functional_call
|
||||
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight
|
||||
|
||||
# dependency on `functional_call` means that this can't be exposed in utils
|
||||
# without creating circular dependency
|
||||
def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
|
||||
def call_for_per_sample_grads(module, batch_size, loss_reduction="sum"):
|
||||
r"""
|
||||
call_for_per_sample_grads(module, batch_size, args, kwargs=None) -> Tensor
|
||||
Invoked just like a forward pass, ``call_for_per_sample_grads`` will produce the same
|
||||
forward result. Then, when backward is invoked, the parameters of ``module``
|
||||
will have a ``grad_sample`` field populated with the per sample gradients
|
||||
instead of the regular gradients
|
||||
call_for_per_sample_grads(module, batch_size, loss_reduction="sum")
|
||||
``call_for_per_sample_grads`` returns a function that is invoked like the forward
|
||||
function of ``module`` and will produce the same result. Then, when backward is invoked,
|
||||
the parameters of ``module`` will have a ``grad_sample`` field populated with the per sample
|
||||
gradients instead of the regular gradients
|
||||
|
||||
Args:
|
||||
module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
|
||||
parameters will compute per sample gradients, located in a ``grad_sample``
|
||||
field when ``backward`` is invoked
|
||||
batch_size: The batch size of the input. Typically the input's first dimension
|
||||
args: Tuple of positional args passed to ``module`` to perform the forward pass
|
||||
kwargs: Dict of named args passed to ``module`` to perform the forward pass. Default: None
|
||||
loss_reduction: The reduction used on the loss. If "mean", per sample gradients will be scaled by the batch size
|
||||
to offset the crossbatch interaction from running mean across a batch. Must be "mean" or "sum". Default: "sum"
|
||||
|
||||
Examples::
|
||||
>>> model = nn.Linear(4, 3)
|
||||
>>> batched_input = torch.randn(5, 4) # batch size of 5
|
||||
>>> res = call_for_per_sample_grads(model, batched_input.shape[0], batched_input).sum()
|
||||
>>> res = call_for_per_sample_grads(model, batched_input.shape[0])(batched_input).sum()
|
||||
>>> res.backward()
|
||||
>>> assert model.weight.shape == (3, 4)
|
||||
>>> assert model.weight.grad_sample.shape == (5, 3, 4)
|
||||
@ -36,12 +38,16 @@ def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
|
||||
Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
|
||||
rewrites that wrap an `nn.Linear` module. See Opacus for an example
|
||||
"""
|
||||
|
||||
def maybe_build_expanded_weight(og_tensor):
|
||||
if og_tensor.requires_grad:
|
||||
return ExpandedWeight(og_tensor, batch_size)
|
||||
return ExpandedWeight(og_tensor, batch_size, loss_reduction)
|
||||
else:
|
||||
return og_tensor
|
||||
|
||||
if loss_reduction not in ["sum", "mean"]:
|
||||
raise RuntimeError(f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}")
|
||||
|
||||
if not isinstance(module, torch.nn.Module):
|
||||
raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}")
|
||||
if not isinstance(batch_size, int):
|
||||
@ -53,5 +59,9 @@ def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
|
||||
raise RuntimeError("Current Expanded Weights accumulates the gradients, which will be incorrect for multiple "
|
||||
f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or "
|
||||
"post an issue to pytorch/pytorch to prioritize correct behavior")
|
||||
params = {name: maybe_build_expanded_weight(value) for (name, value) in module.named_parameters()}
|
||||
return functional_call(module, params, args, kwargs)
|
||||
|
||||
@functools.wraps(module.forward)
|
||||
def wrapper(*args, **kwargs):
|
||||
params = {name: maybe_build_expanded_weight(value) for (name, value) in module.named_parameters()}
|
||||
return functional_call(module, params, args, kwargs)
|
||||
return wrapper
|
||||
|
Reference in New Issue
Block a user