[Expanded Weights] fix loss reduction (#80892)
Two changes here:

(1) Makes `call_for_per_sample_grads` curried. The old call looked like `call_for_per_sample_grads(module, batch_size, args, kwargs)`; the new call looks like `call_for_per_sample_grads(module, batch_size, loss_reduction=loss_reduction)(args, kwargs)`.

(2) Adds the ability to specify a loss reduction, to match what Opacus does. Opacus has a more complete explanation, but essentially the per-sample gradient behavior should match what happens in a for loop over single examples. A mean reduction at the end breaks this, because over a batch it scales all the grad_outputs by 1/batch_size, so we offset that by scaling all the grad_samples by batch_size when loss_reduction is mean.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80892
Approved by: https://github.com/zou3519
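As a rough illustration of the new calling convention, here is a minimal sketch (not part of the PR; the import path, model, and shapes are assumptions for illustration):

```python
import torch
import torch.nn as nn
# Import path assumed; call_for_per_sample_grads is private API at this point.
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

batch_size = 8
model = nn.Linear(10, 4)
x = torch.randn(batch_size, 10)
targets = torch.randint(0, 4, (batch_size,))

# Old (removed) style: call_for_per_sample_grads(model, batch_size, x)
# New curried style: the first call binds module, batch size, and loss reduction;
# the second call runs the forward pass.
out = call_for_per_sample_grads(model, batch_size, loss_reduction="mean")(x)
loss = nn.CrossEntropyLoss(reduction="mean")(out, targets)
loss.backward()

# Each parameter now carries a per-sample gradient of shape (batch_size, *param.shape).
# With loss_reduction="mean", the grad_samples are scaled by batch_size internally so
# they match what a per-example for loop would compute.
print(model.weight.grad_sample.shape)  # expected: torch.Size([8, 4, 10])
```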
@@ -27,8 +27,12 @@ class TestExpandedWeightHelperFunction(TestCase):
         weight = torch.randn(5, 4, device=device)
         bias = torch.randn(5, device=device)
         for (weight_batched, bias_batched) in product([True, False], [True, False]):
-            maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3) if weight_batched else weight
-            maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3) if bias_batched else bias
+            maybe_batched_weight = weight
+            maybe_batched_bias = bias
+            if weight_batched:
+                maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3, loss_reduction="sum")
+            if bias_batched:
+                maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3, loss_reduction="sum")
             args = (input, maybe_batched_weight, maybe_batched_bias)
             expanded_args, expanded_kwargs = standard_kwargs(('bias',), args)
             res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
@@ -45,7 +49,7 @@ class TestExpandedWeightHelperFunction(TestCase):
         weight = torch.randn(5, 4, device=device)
         bias = torch.randn(5, device=device)
         with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."):
-            input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3)
+            input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum")
             expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias))
             forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
         with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"):
@@ -61,18 +65,22 @@ class TestExpandedWeightHelperFunction(TestCase):
         for (weight_batched, bias_batched) in product([True, False], [True, False]):
             if not weight_batched and not bias_batched:
                 continue
-            maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4) if weight_batched else weight
-            maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4) if bias_batched else bias
+            maybe_batched_weight = weight
+            maybe_batched_bias = bias
+            if weight_batched:
+                maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4, loss_reduction="sum")
+            if bias_batched:
+                maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4, loss_reduction="sum")
             with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"):
                 expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias))
                 forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

     def test_set_grad_sample_if_exists(self, device):
-        def test_fn(_):
+        def test_fn(a):
             return True

         orig_weight = torch.randn(4, device=device, requires_grad=True)
-        expanded_weight = ExpandedWeight(orig_weight, 3)
+        expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
         set_grad_sample_if_exists(expanded_weight, test_fn)
         self.assertTrue(hasattr(orig_weight, 'grad_sample'))
         self.assertTrue(orig_weight.grad_sample)
@@ -86,7 +94,7 @@ class TestExpandedWeightHelperFunction(TestCase):
         self.assertFalse(hasattr(non_tensor, 'grad_sample'))

     def test_set_grad_sample_if_exists_failure(self, device):
-        def test_fn(_):
+        def test_fn(a):
             return True

         grad_tensor = torch.randn(4, requires_grad=True, device=device)
@@ -95,7 +103,7 @@ class TestExpandedWeightHelperFunction(TestCase):

     def test_unpack_expanded_weight_or_tensor(self, device):
         input = torch.randn(3, requires_grad=True, device=device)
-        self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3)))
+        self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum")))

         input.requires_grad_(False)
         self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
@@ -103,7 +111,7 @@ class TestExpandedWeightHelperFunction(TestCase):

     def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
         input = torch.randn(3, requires_grad=True, device=device)
-        self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3), lambda x: x is input))
+        self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input))

         input.requires_grad_(False)
         self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
@@ -131,38 +139,58 @@ class TestExpandedWeightHelperFunction(TestCase):
         self.assertEqual(res, input)

 class TestExpandedWeightFunctional(TestCase):
+    def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
+        input = sample_input.input
+        args = sample_input.args
+        kwargs = sample_input.kwargs
+        batch_size = input.shape[0] if len(input.shape) > 1 else 1
+
+        # get per sample grads with ExpandedWeights objects
+        loss_reduction = "sum" if reduction == torch.sum else "mean"
+        (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction)
+        diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
+        diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
+        diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
+        if not diff_input_list:
+            return
+        result = run_op(op, ew_input, *ew_args, **ew_kwargs)
+        reduction(result).backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
+        expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list)
+
+        # get per sample grads with for loop
+        func = partial(run_op, op)
+
+        per_sample_grad = for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs)
+
+        # check equality
+        self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
+        if loss_reduction == "mean":
+            # don't check equality of `input.grad`s since these vanilla tensors won't be scaled
+            expanded_weight_grad = expanded_weight_grad[1:]
+            per_sample_grad = per_sample_grad[1:]
+        for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
+            self.assertEqual(result_grad, expected_grad)
+
     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
-    def test_expanded_weight_per_sample_grad(self, device, dtype, op):
+    def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
         sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
         for sample_input in supported_inputs(op, sample_inputs):
             if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests
                 sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
-            input = sample_input.input
-            args = sample_input.args
-            kwargs = sample_input.kwargs
-            batch_size = input.shape[0] if len(input.shape) > 1 else 1
-
-            # get per sample grads with ExpandedWeights objects
-            (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
-            diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
-            diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
-            diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
-            if not diff_input_list:
-                continue
-            result = run_op(op, ew_input, *ew_args, **ew_kwargs)
-            result.sum().backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
-            expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list)
-            def reduction(x):
-                return x.sum()
-
-            # get per sample grads with for loop
-            func = partial(run_op, op)
-            per_sample_grad = for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs)
+            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)
-
-            # check equality
-            self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
-            for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
-                if result_grad is None:
-                    result_grad = torch.zeros_like(expected_grad)
-                self.assertEqual(result_grad, expected_grad)
+    @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
+    def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
+        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
+        for sample_input in supported_inputs(op, sample_inputs):
+            if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests
+                sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
+
+            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
     def test_unsupported_expand_weights(self, device, dtype, op):
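The mean-reduction bookkeeping that the comparison helper above exercises comes down to a simple scaling argument. A minimal standalone sketch (names and shapes are illustrative, not taken from the PR):

```python
import torch

B = 4
w = torch.randn(3, requires_grad=True)
x = torch.randn(B, 3)

# Reference: differentiate each example's loss on its own, as the for-loop helper does.
per_sample_ref = torch.stack([torch.autograd.grad((x[i] * w).sum(), w)[0] for i in range(B)])

# Batched mean loss: autograd effectively scales every example's grad_output by 1/B.
loss = (x * w).sum(dim=1).mean()
(grad_w,) = torch.autograd.grad(loss, w)

# A per-sample-grad implementation that hooks the batched backward therefore multiplies
# its grad_samples by B under a mean reduction; the recovered per-sample grads then sum
# to B * grad_w, exactly what the unscaled per-example references sum to.
assert torch.allclose(per_sample_ref.sum(dim=0), B * grad_w, atol=1e-6)
```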
@@ -195,34 +223,35 @@ class TestExpandedWeightFunctional(TestCase):
             if "cuda" in device and "max_norm" in sample_input.kwargs and "padding_idx" in sample_input.kwargs:
                 self.skipTest("embedding is non-determinstic in this case, see issue #74679")
             batch_size = sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
-            (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
-            expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
-            normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
-            self.assertEqual(expanded_weight_result, normal_result)
+            for loss_reduction in ["sum", "mean"]:
+                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction)
+                expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
+                normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
+                self.assertEqual(expanded_weight_result, normal_result)

     def test_expanded_weight_error(self, device):
         batch_size = 3
         sample_input = make_tensor((batch_size, 4), dtype=torch.float32, device=device, requires_grad=True)
         sample_weight = make_tensor((4), dtype=torch.float32, device=device, requires_grad=True)
         with self.assertRaisesRegex(RuntimeError, r"Expanded Weights encountered but cannot handle function"):
-            torch.add(sample_input, ExpandedWeight(sample_weight, batch_size))
+            torch.add(sample_input, ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"))

     def _test_embedding_model(self, model, num_embedding, device):
         batch_size = 32
         input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
         return self._test_model(partial(model, num_embedding=num_embedding), batch_size, input, device)

-    def _test_conv_model(self, model, input_size, num_dim, device):
+    def _test_conv_model(self, model, input_size, num_dim, device, loss_reduction="sum"):
         batch_size = 32
         input_ending = [input_size] * num_dim
         input = torch.randn([batch_size, 3] + input_ending, device=device)
-        return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device)
+        return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device, loss_reduction)

-    def _test_model(self, model, batch_size, input, device):
+    def _test_model(self, model, batch_size, input, device, loss_reduction="sum"):
         model = model(10).to(device)
         targets = torch.randint(0, 10, (batch_size,), device=device)
-        criterion = CrossEntropyLoss(reduction='sum')  # use a loss that doesn't average across the batch to test in a for loop
-        result = call_for_per_sample_grads(model, batch_size, input)
+        criterion = CrossEntropyLoss(reduction=loss_reduction)
+        result = call_for_per_sample_grads(model, batch_size, loss_reduction=loss_reduction)(input)
         loss = criterion(result, targets)
         loss.backward()
         result = []
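The `_test_model` change above is the end-to-end pattern: the reduction handed to `call_for_per_sample_grads` has to agree with the reduction of the loss that is actually backpropagated, otherwise the per-sample grads come out off by a factor of batch_size. A hedged sketch of that pairing (the wrapper name and import path are made up for illustration):

```python
import torch.nn as nn
# Import path assumed; call_for_per_sample_grads is private API at this point.
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

def per_sample_grads(model, x, targets, batch_size, loss_reduction="mean"):
    # Hypothetical convenience wrapper: bind the reduction once and use it for both
    # the per-sample-grad machinery and the criterion, so the 1/batch_size factor
    # introduced by a mean loss is offset consistently.
    out = call_for_per_sample_grads(model, batch_size, loss_reduction=loss_reduction)(x)
    nn.CrossEntropyLoss(reduction=loss_reduction)(out, targets).backward()
    return {name: p.grad_sample for name, p in model.named_parameters()}
```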
@@ -240,7 +269,7 @@ class TestExpandedWeightFunctional(TestCase):
         self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)


-    def test_cnn_model(self, device):
+    def test_cnn_model_sum(self, device):
         def convnet(num_classes, num_dim):
             return nn.Sequential(
                 nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
@@ -261,6 +290,27 @@ class TestExpandedWeightFunctional(TestCase):

         return self._test_conv_model(convnet, 28, 2, device)

+    def test_cnn_model_mean(self, device):
+        def convnet(num_classes, num_dim):
+            return nn.Sequential(
+                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
+                nn.ReLU(),
+                nn.AvgPool2d(kernel_size=2, stride=2),
+                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
+                nn.ReLU(),
+                nn.AvgPool2d(kernel_size=2, stride=2),
+                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
+                nn.ReLU(),
+                nn.AvgPool2d(kernel_size=2, stride=2),
+                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
+                nn.ReLU(),
+                nn.AdaptiveAvgPool2d((1, 1)),
+                nn.Flatten(start_dim=1, end_dim=-1),
+                nn.Linear(128, num_classes, bias=True),
+            )
+
+        return self._test_conv_model(convnet, 28, 2, device, loss_reduction="mean")
+
     @parametrize('num_dim', [1, 2, 3])
     def test_instance_norm_model(self, num_dim, device):
         def instance_norm_model(num_classes, num_dim):
@@ -326,7 +376,7 @@ class TestExpandedWeightModule(TestCase):
         input.requires_grad_()
         with freeze_rng_state():
             # get per sample grads with ExpandedWeights context manager
-            actual_res = call_for_per_sample_grads(module, batch_size, input).sum()
+            actual_res = call_for_per_sample_grads(module, batch_size, loss_reduction="sum")(input).sum()
             actual_res.backward()
             actual_grads = []
             for param in module.parameters():
@@ -368,7 +418,7 @@ class TestExpandedWeightModule(TestCase):
         with freeze_rng_state():
             # get per sample grads with ExpandedWeights context manager, calling .backward() twice
             test_module = TestModule(module)
-            actual_res = call_for_per_sample_grads(test_module, batch_size, input).sum()
+            actual_res = call_for_per_sample_grads(test_module, batch_size, loss_reduction="sum")(input).sum()
             actual_res.backward()
             actual_grads = []
             for param in module.parameters():
@@ -397,15 +447,17 @@ class TestExpandedWeightModule(TestCase):
         module = nn.Linear(10, 10)
         input = torch.randn(64, 10)
         with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
-            call_for_per_sample_grads("fail", 64, input)
+            call_for_per_sample_grads("fail", 64)(input)
         with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be an integer"):
-            call_for_per_sample_grads(module, 6.4, input)
+            call_for_per_sample_grads(module, 6.4)(input)
         with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
-            call_for_per_sample_grads(module, -64, input)
+            call_for_per_sample_grads(module, -64)(input)
         with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
-            loss = call_for_per_sample_grads(module, 64, input).sum()
+            loss = call_for_per_sample_grads(module, 64)(input).sum()
             loss.backward()  # populate grad_sample fields
-            call_for_per_sample_grads(module, 64, input)
+            call_for_per_sample_grads(module, 64)(input)
+        with self.assertRaisesRegex(RuntimeError, r"Expected loss_reduction argument to be sum or mean"):
+            call_for_per_sample_grads(module, -64, loss_reduction="")(input)

 class ContextManagerTests(TestBase):
     def __init__(self, *args, **kwargs):
@@ -478,9 +530,11 @@ def run_op(op, input, *args, **kwargs):
     else:
         return op(input, *args, **kwargs)

-def make_expanded_weight(sample_input, batch_size):
+def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
     def expanded_weight_or_clone(arg):
-        return ExpandedWeight(torch.clone(arg), batch_size) if is_diff_tensor(arg) else clone_if_tensor(arg)
+        if is_diff_tensor(arg):
+            return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
+        return clone_if_tensor(arg)

     ew_input = clone_if_tensor(sample_input.input)
     ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
@@ -513,12 +567,12 @@ def supported_inputs(op, sample_inputs, supported_inputs=True):
         return is_supported_input if supported_inputs else not is_supported_input
     return [input for input in sample_inputs if filter_fn(input)]

-def for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs):
+def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
     # get per sample grads by getting derivative for each input in a for loop
     per_sample_grad = []
     for i in range(batch_size):
         per_sample_input = input[i]
-        result = func(per_sample_input.unsqueeze(0), *args, **kwargs)
+        result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
         diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
         diff_input_list = [i for i in diff_input_list if isinstance(i, torch.Tensor) and i.requires_grad]
         per_sample_grad.append(torch.autograd.grad(result, diff_input_list, torch.ones_like(result), allow_unused=True))