[Expanded Weights] fix loss reduction (#80892)

Two changes in here:
(1) Changes `call_for_per_sample_grads` to be curried. Old call looks like:
`call_for_per_sample_grads(module, batch_size, args, kwargs)`
New call looks like:
`call_for_per_sample_grads(module, batch_size, loss_reduction=loss_reduction)(args, kwargs)`

(2) Adds the ability to specify a loss reduction, to match what is done in Opacus. Opacus has a more complete explanation but essentially, they want the per sample gradient behavior to match what is happens in a for loop with a single example. This gets messed up if you use a mean reduction at the end since in a batch that ends up scaling all the grad_outputs by 1/batch_size, so we offset that by scaling all the grad_samples by batch_size if the loss_reduction is mean
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80892
Approved by: https://github.com/zou3519
This commit is contained in:
samdow
2022-07-18 13:19:16 -04:00
committed by PyTorch MergeBot
parent 71e16f9eef
commit 799bc645d9
4 changed files with 155 additions and 72 deletions

View File

@ -27,8 +27,12 @@ class TestExpandedWeightHelperFunction(TestCase):
weight = torch.randn(5, 4, device=device)
bias = torch.randn(5, device=device)
for (weight_batched, bias_batched) in product([True, False], [True, False]):
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3) if weight_batched else weight
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3) if bias_batched else bias
maybe_batched_weight = weight
maybe_batched_bias = bias
if weight_batched:
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3, loss_reduction="sum")
if bias_batched:
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3, loss_reduction="sum")
args = (input, maybe_batched_weight, maybe_batched_bias)
expanded_args, expanded_kwargs = standard_kwargs(('bias',), args)
res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
@ -45,7 +49,7 @@ class TestExpandedWeightHelperFunction(TestCase):
weight = torch.randn(5, 4, device=device)
bias = torch.randn(5, device=device)
with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."):
input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3)
input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum")
expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias))
forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"):
@ -61,18 +65,22 @@ class TestExpandedWeightHelperFunction(TestCase):
for (weight_batched, bias_batched) in product([True, False], [True, False]):
if not weight_batched and not bias_batched:
continue
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4) if weight_batched else weight
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4) if bias_batched else bias
maybe_batched_weight = weight
maybe_batched_bias = bias
if weight_batched:
maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4, loss_reduction="sum")
if bias_batched:
maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4, loss_reduction="sum")
with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"):
expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias))
forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
def test_set_grad_sample_if_exists(self, device):
def test_fn(_):
def test_fn(a):
return True
orig_weight = torch.randn(4, device=device, requires_grad=True)
expanded_weight = ExpandedWeight(orig_weight, 3)
expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
set_grad_sample_if_exists(expanded_weight, test_fn)
self.assertTrue(hasattr(orig_weight, 'grad_sample'))
self.assertTrue(orig_weight.grad_sample)
@ -86,7 +94,7 @@ class TestExpandedWeightHelperFunction(TestCase):
self.assertFalse(hasattr(non_tensor, 'grad_sample'))
def test_set_grad_sample_if_exists_failure(self, device):
def test_fn(_):
def test_fn(a):
return True
grad_tensor = torch.randn(4, requires_grad=True, device=device)
@ -95,7 +103,7 @@ class TestExpandedWeightHelperFunction(TestCase):
def test_unpack_expanded_weight_or_tensor(self, device):
input = torch.randn(3, requires_grad=True, device=device)
self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3)))
self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum")))
input.requires_grad_(False)
self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
@ -103,7 +111,7 @@ class TestExpandedWeightHelperFunction(TestCase):
def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
input = torch.randn(3, requires_grad=True, device=device)
self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3), lambda x: x is input))
self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input))
input.requires_grad_(False)
self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
@ -131,38 +139,58 @@ class TestExpandedWeightHelperFunction(TestCase):
self.assertEqual(res, input)
class TestExpandedWeightFunctional(TestCase):
def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
input = sample_input.input
args = sample_input.args
kwargs = sample_input.kwargs
batch_size = input.shape[0] if len(input.shape) > 1 else 1
# get per sample grads with ExpandedWeights objects
loss_reduction = "sum" if reduction == torch.sum else "mean"
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction)
diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
if not diff_input_list:
return
result = run_op(op, ew_input, *ew_args, **ew_kwargs)
reduction(result).backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__
expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list)
# get per sample grads with for loop
func = partial(run_op, op)
per_sample_grad = for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs)
# check equality
self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
if loss_reduction == "mean":
# don't check equality of `input.grad`s since these vanilla tensors won't be scaled
expanded_weight_grad = expanded_weight_grad[1:]
per_sample_grad = per_sample_grad[1:]
for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
self.assertEqual(result_grad, expected_grad)
@ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
def test_expanded_weight_per_sample_grad(self, device, dtype, op):
def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
for sample_input in supported_inputs(op, sample_inputs):
if op.name == "nn.functional.embedding": # embedding flips its argument order for autograd tests
sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
input = sample_input.input
args = sample_input.args
kwargs = sample_input.kwargs
batch_size = input.shape[0] if len(input.shape) > 1 else 1
# get per sample grads with ExpandedWeights objects
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
if not diff_input_list:
continue
result = run_op(op, ew_input, *ew_args, **ew_kwargs)
result.sum().backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__
expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list)
def reduction(x):
return x.sum()
# get per sample grads with for loop
func = partial(run_op, op)
per_sample_grad = for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs)
self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)
# check equality
self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
if result_grad is None:
result_grad = torch.zeros_like(expected_grad)
self.assertEqual(result_grad, expected_grad)
@ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
for sample_input in supported_inputs(op, sample_inputs):
if op.name == "nn.functional.embedding": # embedding flips its argument order for autograd tests
sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)
@ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
def test_unsupported_expand_weights(self, device, dtype, op):
@ -195,34 +223,35 @@ class TestExpandedWeightFunctional(TestCase):
if "cuda" in device and "max_norm" in sample_input.kwargs and "padding_idx" in sample_input.kwargs:
self.skipTest("embedding is non-determinstic in this case, see issue #74679")
batch_size = sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
self.assertEqual(expanded_weight_result, normal_result)
for loss_reduction in ["sum", "mean"]:
(ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction)
expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
self.assertEqual(expanded_weight_result, normal_result)
def test_expanded_weight_error(self, device):
batch_size = 3
sample_input = make_tensor((batch_size, 4), dtype=torch.float32, device=device, requires_grad=True)
sample_weight = make_tensor((4), dtype=torch.float32, device=device, requires_grad=True)
with self.assertRaisesRegex(RuntimeError, r"Expanded Weights encountered but cannot handle function"):
torch.add(sample_input, ExpandedWeight(sample_weight, batch_size))
torch.add(sample_input, ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"))
def _test_embedding_model(self, model, num_embedding, device):
batch_size = 32
input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
return self._test_model(partial(model, num_embedding=num_embedding), batch_size, input, device)
def _test_conv_model(self, model, input_size, num_dim, device):
def _test_conv_model(self, model, input_size, num_dim, device, loss_reduction="sum"):
batch_size = 32
input_ending = [input_size] * num_dim
input = torch.randn([batch_size, 3] + input_ending, device=device)
return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device)
return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device, loss_reduction)
def _test_model(self, model, batch_size, input, device):
def _test_model(self, model, batch_size, input, device, loss_reduction="sum"):
model = model(10).to(device)
targets = torch.randint(0, 10, (batch_size,), device=device)
criterion = CrossEntropyLoss(reduction='sum') # use a loss that doesn't average across the batch to test in a for loop
result = call_for_per_sample_grads(model, batch_size, input)
criterion = CrossEntropyLoss(reduction=loss_reduction)
result = call_for_per_sample_grads(model, batch_size, loss_reduction=loss_reduction)(input)
loss = criterion(result, targets)
loss.backward()
result = []
@ -240,7 +269,7 @@ class TestExpandedWeightFunctional(TestCase):
self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)
def test_cnn_model(self, device):
def test_cnn_model_sum(self, device):
def convnet(num_classes, num_dim):
return nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
@ -261,6 +290,27 @@ class TestExpandedWeightFunctional(TestCase):
return self._test_conv_model(convnet, 28, 2, device)
def test_cnn_model_mean(self, device):
def convnet(num_classes, num_dim):
return nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(start_dim=1, end_dim=-1),
nn.Linear(128, num_classes, bias=True),
)
return self._test_conv_model(convnet, 28, 2, device, loss_reduction="mean")
@parametrize('num_dim', [1, 2, 3])
def test_instance_norm_model(self, num_dim, device):
def instance_norm_model(num_classes, num_dim):
@ -326,7 +376,7 @@ class TestExpandedWeightModule(TestCase):
input.requires_grad_()
with freeze_rng_state():
# get per sample grads with ExpandedWeights context manager
actual_res = call_for_per_sample_grads(module, batch_size, input).sum()
actual_res = call_for_per_sample_grads(module, batch_size, loss_reduction="sum")(input).sum()
actual_res.backward()
actual_grads = []
for param in module.parameters():
@ -368,7 +418,7 @@ class TestExpandedWeightModule(TestCase):
with freeze_rng_state():
# get per sample grads with ExpandedWeights context manager, calling .backward() twice
test_module = TestModule(module)
actual_res = call_for_per_sample_grads(test_module, batch_size, input).sum()
actual_res = call_for_per_sample_grads(test_module, batch_size, loss_reduction="sum")(input).sum()
actual_res.backward()
actual_grads = []
for param in module.parameters():
@ -397,15 +447,17 @@ class TestExpandedWeightModule(TestCase):
module = nn.Linear(10, 10)
input = torch.randn(64, 10)
with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
call_for_per_sample_grads("fail", 64, input)
call_for_per_sample_grads("fail", 64)(input)
with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be an integer"):
call_for_per_sample_grads(module, 6.4, input)
call_for_per_sample_grads(module, 6.4)(input)
with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
call_for_per_sample_grads(module, -64, input)
call_for_per_sample_grads(module, -64)(input)
with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
loss = call_for_per_sample_grads(module, 64, input).sum()
loss = call_for_per_sample_grads(module, 64)(input).sum()
loss.backward() # populate grad_sample fields
call_for_per_sample_grads(module, 64, input)
call_for_per_sample_grads(module, 64)(input)
with self.assertRaisesRegex(RuntimeError, r"Expected loss_reduction argument to be sum or mean"):
call_for_per_sample_grads(module, -64, loss_reduction="")(input)
class ContextManagerTests(TestBase):
def __init__(self, *args, **kwargs):
@ -478,9 +530,11 @@ def run_op(op, input, *args, **kwargs):
else:
return op(input, *args, **kwargs)
def make_expanded_weight(sample_input, batch_size):
def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
def expanded_weight_or_clone(arg):
return ExpandedWeight(torch.clone(arg), batch_size) if is_diff_tensor(arg) else clone_if_tensor(arg)
if is_diff_tensor(arg):
return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
return clone_if_tensor(arg)
ew_input = clone_if_tensor(sample_input.input)
ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
@ -513,12 +567,12 @@ def supported_inputs(op, sample_inputs, supported_inputs=True):
return is_supported_input if supported_inputs else not is_supported_input
return [input for input in sample_inputs if filter_fn(input)]
def for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs):
def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
# get per sample grads by getting derivative for each input in a for loop
per_sample_grad = []
for i in range(batch_size):
per_sample_input = input[i]
result = func(per_sample_input.unsqueeze(0), *args, **kwargs)
result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
diff_input_list = [i for i in diff_input_list if isinstance(i, torch.Tensor) and i.requires_grad]
per_sample_grad.append(torch.autograd.grad(result, diff_input_list, torch.ones_like(result), allow_unused=True))