# Owner(s): ["module: nn"]
import unittest
from dataclasses import dataclass
from functools import partial
from itertools import chain, product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.utils._expanded_weights import ExpandedWeight
from torch.nn.utils._expanded_weights.expanded_weights_utils import (
    forward_helper,
    set_grad_sample_if_exists,
    standard_kwargs,
    sum_over_all_but_batch_and_last_n,
    unpack_expanded_weight_or_tensor,
)
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_nn import (
    get_new_module_tests,
    module_tests,
    TestBase,
)
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    make_tensor,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._pytree import tree_map_only


class TestContext:
    pass
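

# Tests for the low-level helpers in expanded_weights_utils: forward_helper,
# standard_kwargs, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor,
# and sum_over_all_but_batch_and_last_n.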
class TestExpandedWeightHelperFunction(TestCase):
    def test_forward_helper(self, device):
        input = torch.randn(3, 4, device=device)
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            args = (input, maybe_batched_weight, maybe_batched_bias)
            expanded_args, expanded_kwargs = standard_kwargs(("bias",), args)
            res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
            expected = nn.functional.linear(input, weight, bias)
            self.assertEqual(res, expected)

            self.assertEqual(len(expanded_args), 2)
            assert expanded_args[0] is args[0]  # avoids property checks in assertEquals
            assert expanded_args[1] is args[1]  # avoids property checks in assertEquals
            self.assertEqual(len(expanded_kwargs), 1)
            assert (
                expanded_kwargs["bias"] is args[2]
            )  # avoids property checks in assertEquals

    def test_forward_helper_failure_args(self, device):
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        with self.assertRaisesRegex(
            RuntimeError, r"do not support inputs that are also ExpandedWeights."
        ):
            input = ExpandedWeight(
                torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum"
            )
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (input, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a Tensor as the first input"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (3, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a batch dimension but got an input of size 0"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.tensor(3), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"0 is not a valid batch size for Expanded Weights"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.randn(0, 1, 2), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        input = torch.randn(3, 4)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            if not weight_batched and not bias_batched:
                continue
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            with self.assertRaisesRegex(
                RuntimeError,
                r"Expected ExpandedWeights to have batch size matching input",
            ):
                expanded_args, expanded_kwargs = standard_kwargs(
                    ("bias",), (input, maybe_batched_weight, maybe_batched_bias)
                )
                forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    def test_set_grad_sample_if_exists(self, device):
        def test_fn(a):
            return grad_sample

        orig_weight = torch.randn(4, device=device, requires_grad=True)
        expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
        grad_sample = torch.randn(3)
        set_grad_sample_if_exists(expanded_weight, test_fn)
        self.assertTrue(hasattr(orig_weight, "grad_sample"))
        self.assertEqual(orig_weight.grad_sample, grad_sample)

        basic_tensor = torch.randn(4, device=device)
        set_grad_sample_if_exists(basic_tensor, test_fn)
        self.assertFalse(hasattr(basic_tensor, "grad_sample"))

        non_tensor = 3
        set_grad_sample_if_exists(non_tensor, test_fn)
        self.assertFalse(hasattr(non_tensor, "grad_sample"))

    def test_set_grad_sample_if_exists_failure(self, device):
        def test_fn(a):
            return True

        grad_tensor = torch.randn(4, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            set_grad_sample_if_exists(grad_tensor, test_fn)

    def test_unpack_expanded_weight_or_tensor(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertEqual(
            input,
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum")
            ),
        )

        input.requires_grad_(False)
        self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
        self.assertTrue(unpack_expanded_weight_or_tensor(4) is None)

    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertTrue(
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input
            )
        )

        input.requires_grad_(False)
        self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
        self.assertTrue(
            unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None
        )

    def test_unpack_expanded_weight_or_tensor_failure(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input)

        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input, lambda x: x is input)

    def test_sum_over_all_but_batch_and_last_n(self, device):
        input = torch.randn(1, 2, 3, 4, 5, device=device)
        res = sum_over_all_but_batch_and_last_n(input, 2)
        expected = input.sum((1, 2))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 0)
        expected = input.sum((1, 2, 3, 4))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 4)
        self.assertEqual(res, input)
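

# Functional-level tests: for each OpInfo that supports ExpandedWeights, compare
# the per-sample gradients produced through ExpandedWeight tensors against
# gradients computed one sample at a time in a for loop, and check that the
# forward pass matches the unexpanded op. Also exercises small CNN / norm /
# embedding models end to end.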
class TestExpandedWeightFunctional(TestCase):
    def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
        input = sample_input.input
        args = sample_input.args
        kwargs = sample_input.kwargs
        batch_size = input.shape[0] if len(input.shape) > 1 else 1

        # get per sample grads with ExpandedWeights objects
        loss_reduction = "sum" if reduction == torch.sum else "mean"
        (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
            sample_input, batch_size, loss_reduction
        )
        diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
        diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
        diff_input_list = [
            i.orig_weight if isinstance(i, ExpandedWeight) else i
            for i in diff_input_list
        ]
        if not diff_input_list:
            return
        result = run_op(op, ew_input, *ew_args, **ew_kwargs)
        reduction(
            result
        ).backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
        expanded_weight_grad = tuple(
            i.grad_sample if hasattr(i, "grad_sample") else i.grad
            for i in diff_input_list
        )

        # get per sample grads with for loop
        func = partial(run_op, op)

        per_sample_grad = for_loop_per_sample_grad(
            batch_size, reduction, input, func, *args, **kwargs
        )

        # check equality
        self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
        if loss_reduction == "mean":
            # don't check equality of `input.grad`s since these vanilla tensors won't be scaled
            expanded_weight_grad = expanded_weight_grad[1:]
            per_sample_grad = per_sample_grad[1:]
        for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad):
            self.assertEqual(result_grad, expected_grad)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )
            sample_input.input.requires_grad_(False)

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @skipIfTorchDynamo("Checking error message doesn't work with dynamo")
    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_unsupported_expand_weights(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False)
        for sample_input in unsupported_inputs:
            with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"):
                if (
                    op.name == "nn.functional.embedding"
                ):  # embedding flips its argument order for autograd tests
                    sample_input = SampleInput(
                        sample_input.args[0],
                        args=(sample_input.input,),
                        kwargs=sample_input.kwargs,
                    )
                input = sample_input.input

                batch_size = input.shape[0] if len(input.shape) > 1 else 1

                # get per sample grads with ExpandedWeights objects
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size
                )
                result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                diff_input_list = (
                    (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
                )
                diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
                diff_input_list = [
                    i.orig_weight if isinstance(i, ExpandedWeight) else i
                    for i in diff_input_list
                ]
                result.sum().backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported
    )
    def test_expanded_weight_forward(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0].clone(),
                    args=(sample_input.input.clone(),),
                    kwargs=sample_input.kwargs,
                )
                if (
                    "cuda" in device
                    and "max_norm" in sample_input.kwargs
                    and "padding_idx" in sample_input.kwargs
                ):
                    self.skipTest(
                        "embedding is non-deterministic in this case, see issue #74679"
                    )
            batch_size = (
                sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
            )
            for loss_reduction in ["sum", "mean"]:
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size, loss_reduction
                )
                expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                normal_result = run_op(
                    op, sample_input.input, *sample_input.args, **sample_input.kwargs
                )
                self.assertEqual(expanded_weight_result, normal_result)

    def test_expanded_weight_error(self, device):
        batch_size = 3
        sample_input = make_tensor(
            (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True
        )
        sample_weight = make_tensor(
            (4), dtype=torch.float32, device=device, requires_grad=True
        )
        with self.assertRaisesRegex(
            RuntimeError, r"Expanded Weights encountered but cannot handle function"
        ):
            torch.add(
                sample_input,
                ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"),
            )

    def _test_embedding_model(self, model, num_embedding, device):
        batch_size = 32
        input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
        return self._test_model(
            partial(model, num_embedding=num_embedding), batch_size, input, device
        )

    def _test_conv_model(
        self,
        model,
        input_size,
        num_dim,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        batch_size = 32
        input_ending = [input_size] * num_dim
        input = torch.randn([batch_size, 3] + input_ending, device=device)
        return self._test_model(
            partial(model, num_dim=num_dim),
            batch_size,
            input,
            device,
            loss_reduction,
            atol,
            rtol,
        )
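
    # Shared harness: run the model through call_for_per_sample_grads with a
    # CrossEntropyLoss, collect each parameter's grad_sample, and compare against
    # per-sample gradients computed by running the model on one sample at a time.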
    def _test_model(
        self,
        model,
        batch_size,
        input,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        model = model(10).to(device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        criterion = CrossEntropyLoss(reduction=loss_reduction)
        result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input)
        loss = criterion(result, targets)
        loss.backward()
        result = []
        for weight in model.parameters():
            result.append(weight.grad_sample)
            del weight.grad_sample

        expected = []
        for i in range(batch_size):
            loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
            expected.append(
                torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))
            )

        expected = [torch.stack(grad) for grad in zip(*expected)]
        for res, exp in zip(result, expected):
            self.assertEqual(res, exp, atol=atol, rtol=rtol)

    def _compute_tolerances(self, device):
        is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(
            0
        ) == (8, 6)
        return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5)

    @tf32_off()
    def test_cnn_model_sum(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol)

    @tf32_off()
    def test_cnn_model_mean(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_instance_norm_model(self, num_dim, device):
        def instance_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            norm_layer = (
                nn.InstanceNorm1d
                if num_dim == 1
                else nn.InstanceNorm2d
                if num_dim == 2
                else nn.InstanceNorm3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                norm_layer(32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_group_norm_model(self, num_dim, device):
        def group_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(8, 32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_layer_norm_model(self, num_dim, device):
        def layer_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            normalized_shape = [7] * num_dim
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.LayerNorm(normalized_shape, elementwise_affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    def test_embedding_model(self, device):
        def embedding_model(num_classes, num_embedding):
            return nn.Sequential(
                nn.Embedding(num_embedding, 15),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(375, num_classes, bias=True),
            )

        return self._test_embedding_model(embedding_model, 16, device)

    def test_group_norm_error(self, device):
        # group norm has to call native_group_norm. This checks that it hits the same errors
        # that normal group norm would

        N = 3
        C = 5
        inp = torch.randn(N, C)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected number of channels in input to be divisible"
        ):
            F.group_norm(inp, 2)  # 5 is not divisible by 2
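

# Module-level tests: check that call_for_per_sample_grads on nn.Modules
# (including RNN/LSTM/GRU, with and without packed sequences) produces the same
# per-sample gradients as a manual loop over the batch.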
class TestExpandedWeightModule(TestCase):
    def _do_test(
        self,
        module,
        input,
        args=None,
        kwargs=None,
        batch_first=True,
        atol=None,
        rtol=None,
    ):
        args = args or ()
        kwargs = kwargs or {}

        batch_dim = 0 if batch_first else 1
        batch_size = input.shape[batch_dim]
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()

        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module,
                batch_size=batch_size,
                loss_reduction="sum",
                batch_first=batch_first,
            )(input, *args, **kwargs).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop
            expected_res = torch.tensor(
                0.0, device=input.device, dtype=actual_res.dtype
            )
            expected_grads = []
            for i in range(batch_size):
                input_slice = input.narrow(batch_dim, i, 1)
                input_slice = input_slice.squeeze(batch_dim)

                # h's batch dim is always the first dim. Must be contiguous for CUDA
                sliced_args = tree_map_only(
                    torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args
                )
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(
                    input_slice.unsqueeze(batch_dim).contiguous(),
                    *sliced_args,
                    **kwargs,
                ).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
                expected_res += res
            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
            if not batch_first:
                expected_grads[-1] = expected_grads[-1].transpose(0, 1)
        self.assertEqual(actual_res, expected_res, atol=atol, rtol=rtol)
        [
            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
            for (actual, expected) in zip(actual_grads, expected_grads)
        ]

    def _do_test_multi_input(self, module, input):
        class TestModule(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, input):
                return self.module(input) + self.module(input)

        batch_size = input.shape[0]
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager, calling .backward() twice
            test_module = TestModule(module)
            actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(
                input
            ).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop, running over the input twice
            expected_grads = []
            for i in range(batch_size):
                input_slice = input[i]
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(input_slice.unsqueeze(0)).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
        expected_grads = tuple(
            expected_grad
            for expected_grad in expected_grads
            if expected_grad is not None
        )
        assert [
            self.assertEqual(actual, 2 * expected)
            for (actual, expected) in zip(actual_grads, expected_grads)
        ]

    def _do_test_rnn_packed_sequence(
        self, module, input, args=None, kwargs=None, atol=None, rtol=None
    ):
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        batch_size = max(tuple(input.batch_sizes)).item()

        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module, batch_size=batch_size, loss_reduction="sum"
            )(input, *args, **kwargs).data.sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                self.assertEqual(param.grad_sample.shape[0], batch_size)
                actual_grads.append(param.grad_sample)
                del param.grad_sample

            input.data.grad = torch.zeros_like(input.data)

            # compute the per sample grads with a for loop
            expected_res = torch.zeros_like(actual_res)
            expected_grads = []
            padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(
                input, batch_first=True
            )
            for i in range(len(seq_sizes)):
                input_slice = padded_input[i].narrow(0, 0, seq_sizes[i])
                diff_params = module.parameters()
                batch_dim = 0 if module.m.batch_first else 1
                res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum()
                expected_res += res
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)

        expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
        self.assertEqual(actual_res, expected_res, atol=atol, rtol=rtol)
        [
            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
            for (actual, expected) in zip(actual_grads, expected_grads)
        ]
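
    # Runs the RNN-family modules from module_db through the per-sample-grad
    # machinery, batching any unbatched sample inputs first.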
    @modules(
        filter(
            lambda m_info: m_info.module_cls
            in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
    )
    @tf32_off()
    def test_module(self, device, dtype, module_info, training):
        class RNNWrapper(torch.nn.Module):
            def __init__(self, m_cons, args, kwargs):
                super().__init__()
                self.m = m_cons(*args, **kwargs)

            def forward(self, *inps):
                ret = self.m(*inps)
                assert isinstance(ret, tuple)
                return ret[0]

        def batch_hidden(h):
            new_h_shape = [1] * (len(h.shape) + 1)
            new_h_shape[1] = 2
            return h.unsqueeze(1).repeat(new_h_shape)

        module_cls = module_info.module_cls
        atol, rtol = (1e-3, 1e-4) if dtype == torch.float32 else (None, None)
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
            with_packed_sequence=True,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = RNNWrapper(module_cls, args, kwargs)
            batch_first = m.m.batch_first
            m.to(device).to(dtype)

            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )

            # if the RNN tests use unbatched inputs--batch the inputs
            input = args[0]
            if isinstance(input, torch.Tensor) and input.dim() == 2:
                input = input.detach()
                new_input_shape = [1] * (len(input.shape) + 1)
                if batch_first:
                    new_input_shape[0] = 2
                    input = input.repeat(new_input_shape)
                else:
                    new_input_shape[1] = 2
                    input = input.unsqueeze(1).repeat(new_input_shape)

                h = args[1] if len(args) > 1 else None
                if h is not None:
                    h = (
                        batch_hidden(h)
                        if isinstance(h, torch.Tensor)
                        else tuple(batch_hidden(hx) for hx in h)
                    )
                    args = list(args)
                    args[1] = h

            if isinstance(input, torch.nn.utils.rnn.PackedSequence):
                self._do_test_rnn_packed_sequence(
                    m, input, args[1:], kwargs, atol=atol, rtol=rtol
                )
            else:
                self._do_test(
                    m,
                    input,
                    args[1:],
                    kwargs,
                    batch_first=batch_first,
                    atol=atol,
                    rtol=rtol,
                )

    def test_per_sample_api_failing(self):
        module = nn.Linear(10, 10)
        input = torch.randn(64, 10)
        with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
            call_for_per_sample_grads("fail")(input)
        with self.assertRaisesRegex(
            RuntimeError, r"Batch size passed must be None or an integer"
        ):
            call_for_per_sample_grads(module, batch_size=6.4)(input)
        with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
            call_for_per_sample_grads(module, batch_size=-64)(input)
        with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
            loss = call_for_per_sample_grads(module)(input).sum()
            loss.backward()  # populate grad_sample fields
            call_for_per_sample_grads(module)(input)

        module = nn.Linear(10, 10)  # reset to not have grad_sample fields
        with self.assertRaisesRegex(
            RuntimeError, r"Expected loss_reduction argument to be sum or mean"
        ):
            call_for_per_sample_grads(module, loss_reduction="")(input)

    def test_per_sample_api_compute_batch_size(self):
        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1) + self.linear(input2)

        module = CustomModule()
        input1 = torch.randn(4, 5)
        input2 = torch.randn(5, 5)

        with self.assertRaisesRegex(
            RuntimeError,
            "found at least one input with batch size 4 and one with batch size 5",
        ):
            call_for_per_sample_grads(module)(input1, input2)

        input2 = torch.randn(4, 5)
        call_for_per_sample_grads(module)(input1, input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1, input2=input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1=input1, input2=input2)

    def test_per_sample_api_compute_batch_size_not_pytreeable(self):
        @dataclass
        class NonPytreeableTuple:
            elem1: torch.Tensor
            elem2: torch.Tensor

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1.elem1) + self.linear(input1.elem2)

        input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5))
        model = CustomModule()
        with self.assertRaisesRegex(
            RuntimeError,
            "ExpandedWeights cannot compute the batch size from the inputs",
        ):
            call_for_per_sample_grads(model)(input, "")

        # would prefer for it to error because input is not pytree-able but that's hard to detect
        with self.assertRaisesRegex(
            RuntimeError, "Expected ExpandedWeights to have batch size matching input"
        ):
            call_for_per_sample_grads(model)(input, torch.randn(5))

        model = CustomModule()  # TODO: functional call bug, sam will fix
        call_for_per_sample_grads(model)(input, torch.randn(4, 5))
        model = CustomModule()
        call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5))
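

# Adapter that reuses the legacy common_nn module test definitions to generate
# per-sample-grad tests; instances are attached to TestExpandedWeightModule by
# the generation loop below.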
class ContextManagerTests(TestBase):
    def __init__(self, *args, **kwargs):
        self.test_cpu = kwargs.get("test_cpu", True)
        self.test_cuda = kwargs.get("test_cuda", True)
        super().__init__(*args, **kwargs)

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", False)

    def test_context_manager(self, test_case, device):
        kwargs = {"device": device, "dtype": torch.double}
        module = self.constructor(*self.constructor_args).to(**kwargs)
        if "Embedding" in self.get_name():
            kwargs["dtype"] = torch.long
        input = self._get_input().to(**kwargs)
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest(
                "Can't get per sample gradients when no batch dim or batch dim is 0"
            )
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest(
                "Can't get per sample gradients for input of rank 1"
            )
        test_case._do_test(module, input)

    def test_context_manager_multiple_inputs(self, test_case, device):
        module = self.constructor(*self.constructor_args).to(device)
        input = self._get_input()
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest(
                "Can't get per sample gradients when no batch dim or batch dim is 0"
            )
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest(
                "Can't get per sample gradients for input of rank 1"
            )
        test_case._do_test_multi_input(module, input)


def filter_supported_tests(t):
    supported_modules = [
        "Linear",
        "Conv1d",
        "Conv2d",
        "Conv3d",
        "Embedding",
        "LayerNorm",
        "GroupNorm",
        "InstanceNorm",
    ]
    if "module_name" in t and t["module_name"] in supported_modules:
        return True


# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
# These currently use the legacy nn tests
supported_tests = [
    t for t in module_tests + get_new_module_tests() if filter_supported_tests(t)
]
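# For each supported legacy test entry, build a ContextManagerTests instance and
# attach CPU (and, when CUDA is available, double-precision CUDA) test methods to
# TestExpandedWeightModule.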
for test_param in supported_tests:
    if "constructor" not in test_param:
        name = test_param.pop("module_name")
        test_param["constructor"] = getattr(nn, name)
    decorator = test_param.pop("decorator", lambda test: test)
    test = ContextManagerTests(**test_param)
    test_name = test.get_name()
    if hasattr(TestExpandedWeightModule, test_name):
        raise RuntimeError("Found two tests with the same name: " + test_name)
    test_name_multi_input = test.get_name() + "_multiple_inputs"
    if hasattr(TestExpandedWeightModule, test_name_multi_input):
        raise RuntimeError("Found two tests with the same name: " + test_name)
    if test.test_cpu:
        setattr(
            TestExpandedWeightModule,
            test_name,
            decorator(lambda self, test=test: test.test_context_manager(self, "cpu")),
        )
        setattr(
            TestExpandedWeightModule,
            test_name_multi_input,
            decorator(
                lambda self, test=test: test.test_context_manager_multiple_inputs(
                    self, "cpu"
                )
            ),
        )
    if TEST_CUDA and test.test_cuda:
        # since this checks derivatives, only use double for precision
        setattr(
            TestExpandedWeightModule,
            test_name + "_cuda_double",
            decorator(lambda self, test=test: test.test_context_manager(self, "cuda")),
        )


# ------------- HELPER FUNCTIONS -----------------


def run_op(op, input, *args, **kwargs):
    r"""
    OpInfo for Embedding switches the input and weight so autograd tests will only check the derivative
    of the weight, not the input, which can't be differentiable since its dtype is int. Calls op,
    using the special ordering that Embedding's OpInfo expects for that case.
    """
    if op.name == "nn.functional.embedding":
        return op(args[0], input, **kwargs)
    else:
        return op(input, *args, **kwargs)
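

# Wrap every differentiable tensor argument of a SampleInput in an ExpandedWeight
# with the given batch size; non-differentiable tensors are just cloned.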
def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
    def expanded_weight_or_clone(arg):
        if is_diff_tensor(arg):
            return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
        return clone_if_tensor(arg)

    ew_input = clone_if_tensor(sample_input.input)
    ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
    ew_kwargs = {
        name: expanded_weight_or_clone(arg)
        for (name, arg) in sample_input.kwargs.items()
    }
    return ew_input, ew_args, ew_kwargs


def supported_inputs(op, sample_inputs, supported_inputs=True):
    r"""
    ExpandedWeights currently does not support some use cases when there's no batch dimension or
    operations that would cause inter-batch operations. Removes all of the cases it cannot deal with
    """

    def filter_fn(input):
        convolutions = [
            "nn.functional.conv1d",
            "nn.functional.conv2d",
            "nn.functional.conv3d",
        ]
        batched_input_size = dict(zip(convolutions, [3, 4, 5]))
        if op.name == "nn.functional.linear":
            is_supported_input = (
                input.input.dim() > 1
            )  # input of rank 1 means no batch dim
        elif op.name == "nn.functional.layer_norm":
            normalized_shape = input.args[0]
            is_supported_input = (
                input.input.shape != normalized_shape
            )  # would cause inter-batch operations
        elif op.name in convolutions:
            # currently can't deal with padding computation on Python level
            is_supported_input = input.input.dim() == batched_input_size[op.name]
        elif op.name == "nn.functional.embedding":
            idx = input.args[0]
            is_supported_input = len(idx.shape) > 1  # there's no batch size
        else:
            is_supported_input = True
        is_supported_input = (
            is_supported_input and input.input.shape[0] > 0
        )  # 0 is not a valid batch size
        return is_supported_input if supported_inputs else not is_supported_input

    return [input for input in sample_inputs if filter_fn(input)]


def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
    # get per sample grads by getting derivative for each input in a for loop
    per_sample_grad = []
    for i in range(batch_size):
        per_sample_input = input[i]
        result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
        diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
        diff_input_list = [
            i
            for i in diff_input_list
            if isinstance(i, torch.Tensor) and i.requires_grad
        ]
        per_sample_grad.append(
            torch.autograd.grad(
                result, diff_input_list, torch.ones_like(result), allow_unused=True
            )
        )
    if len(per_sample_grad) == batch_size:
        per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad))
    return per_sample_grad


def is_diff_tensor(t):
    return isinstance(t, ExpandedWeight) or (
        isinstance(t, torch.Tensor) and t.requires_grad
    )


def clone_if_tensor(t):
    if isinstance(t, torch.Tensor):
        res = torch.clone(t).detach()
        res.requires_grad_(t.requires_grad)
        return res
    else:
        return t


instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
instantiate_device_type_tests(TestExpandedWeightModule, globals())
if __name__ == "__main__":
    run_tests()