# Owner(s): ["module: nn"]
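#
# Tests for per-sample gradients ("expanded weights"). The file covers three layers:
# the low-level helpers in torch.nn.utils._expanded_weights.expanded_weights_utils,
# the per-op behavior of ExpandedWeight tensors against OpInfo samples, and the
# user-facing torch.nn.utils._per_sample_grad.call_for_per_sample_grads API, each
# checked against per-sample gradients computed one sample at a time in a for loop.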
import unittest
from dataclasses import dataclass
from functools import partial
from itertools import chain, product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.utils._expanded_weights import ExpandedWeight
from torch.nn.utils._expanded_weights.expanded_weights_utils import (
    forward_helper,
    set_grad_sample_if_exists,
    standard_kwargs,
    sum_over_all_but_batch_and_last_n,
    unpack_expanded_weight_or_tensor,
)
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_nn import (
    get_new_module_tests,
    module_tests,
    TestBase,
)
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    make_tensor,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._pytree import tree_map_only


class TestContext:
    pass


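# Unit tests for the standalone utilities used by the ExpandedWeights implementation:
# standard_kwargs/forward_helper, set_grad_sample_if_exists,
# unpack_expanded_weight_or_tensor, and sum_over_all_but_batch_and_last_n,
# including their failure modes.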
class TestExpandedWeightHelperFunction(TestCase):
    def test_forward_helper(self, device):
        input = torch.randn(3, 4, device=device)
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            args = (input, maybe_batched_weight, maybe_batched_bias)
            expanded_args, expanded_kwargs = standard_kwargs(("bias",), args)
            res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
            expected = nn.functional.linear(input, weight, bias)
            self.assertEqual(res, expected)

            self.assertEqual(len(expanded_args), 2)
            assert expanded_args[0] is args[0]  # avoids property checks in assertEquals
            assert expanded_args[1] is args[1]  # avoids property checks in assertEquals
            self.assertEqual(len(expanded_kwargs), 1)
            assert (
                expanded_kwargs["bias"] is args[2]
            )  # avoids property checks in assertEquals

    def test_forward_helper_failure_args(self, device):
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        with self.assertRaisesRegex(
            RuntimeError, r"do not support inputs that are also ExpandedWeights."
        ):
            input = ExpandedWeight(
                torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum"
            )
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (input, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a Tensor as the first input"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (3, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a batch dimension but got an input of size 0"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.tensor(3), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"0 is not a valid batch size for Expanded Weights"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.randn(0, 1, 2), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        input = torch.randn(3, 4)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            if not weight_batched and not bias_batched:
                continue
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            with self.assertRaisesRegex(
                RuntimeError,
                r"Expected ExpandedWeights to have batch size matching input",
            ):
                expanded_args, expanded_kwargs = standard_kwargs(
                    ("bias",), (input, maybe_batched_weight, maybe_batched_bias)
                )
                forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    def test_set_grad_sample_if_exists(self, device):
        def test_fn(a):
            return grad_sample

        orig_weight = torch.randn(4, device=device, requires_grad=True)
        expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
        grad_sample = torch.randn(3)
        set_grad_sample_if_exists(expanded_weight, test_fn)
        self.assertTrue(hasattr(orig_weight, "grad_sample"))
        self.assertEqual(orig_weight.grad_sample, grad_sample)

        basic_tensor = torch.randn(4, device=device)
        set_grad_sample_if_exists(basic_tensor, test_fn)
        self.assertFalse(hasattr(basic_tensor, "grad_sample"))

        non_tensor = 3
        set_grad_sample_if_exists(non_tensor, test_fn)
        self.assertFalse(hasattr(non_tensor, "grad_sample"))

    def test_set_grad_sample_if_exists_failure(self, device):
        def test_fn(a):
            return True

        grad_tensor = torch.randn(4, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            set_grad_sample_if_exists(grad_tensor, test_fn)

    def test_unpack_expanded_weight_or_tensor(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertEqual(
            input,
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum")
            ),
        )

        input.requires_grad_(False)
        self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
        self.assertTrue(unpack_expanded_weight_or_tensor(4) is None)

    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertTrue(
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input
            )
        )

        input.requires_grad_(False)
        self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
        self.assertTrue(
            unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None
        )

    def test_unpack_expanded_weight_or_tensor_failure(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input)

        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input, lambda x: x is input)

    def test_sum_over_all_but_batch_and_last_n(self, device):
        input = torch.randn(1, 2, 3, 4, 5, device=device)
        res = sum_over_all_but_batch_and_last_n(input, 2)
        expected = input.sum((1, 2))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 0)
        expected = input.sum((1, 2, 3, 4))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 4)
        self.assertEqual(res, input)


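# Op-level and small-model checks: per-sample gradients computed through
# ExpandedWeight objects are compared against a reference built by running each
# op (or model) one sample at a time, for both "sum" and "mean" loss reductions.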
class TestExpandedWeightFunctional(TestCase):
    def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
        input = sample_input.input
        args = sample_input.args
        kwargs = sample_input.kwargs
        batch_size = input.shape[0] if len(input.shape) > 1 else 1

        # get per sample grads with ExpandedWeights objects
        loss_reduction = "sum" if reduction == torch.sum else "mean"
        (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
            sample_input, batch_size, loss_reduction
        )
        diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
        diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
        diff_input_list = [
            i.orig_weight if isinstance(i, ExpandedWeight) else i
            for i in diff_input_list
        ]
        if not diff_input_list:
            return
        result = run_op(op, ew_input, *ew_args, **ew_kwargs)
        reduction(
            result
        ).backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
        expanded_weight_grad = tuple(
            i.grad_sample if hasattr(i, "grad_sample") else i.grad
            for i in diff_input_list
        )

        # get per sample grads with for loop
        func = partial(run_op, op)

        per_sample_grad = for_loop_per_sample_grad(
            batch_size, reduction, input, func, *args, **kwargs
        )

        # check equality
        self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
        if loss_reduction == "mean":
            # don't check equality of `input.grad`s since these vanilla tensors won't be scaled
            expanded_weight_grad = expanded_weight_grad[1:]
            per_sample_grad = per_sample_grad[1:]
        for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad):
            self.assertEqual(result_grad, expected_grad)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )
            sample_input.input.requires_grad_(False)

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @skipIfTorchDynamo("Checking error message doesn't work with dynamo")
    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_unsupported_expand_weights(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False)
        for sample_input in unsupported_inputs:
            with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"):
                if (
                    op.name == "nn.functional.embedding"
                ):  # embedding flips its argument order for autograd tests
                    sample_input = SampleInput(
                        sample_input.args[0],
                        args=(sample_input.input,),
                        kwargs=sample_input.kwargs,
                    )
                input = sample_input.input

                batch_size = input.shape[0] if len(input.shape) > 1 else 1

                # get per sample grads with ExpandedWeights objects
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size
                )
                result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                diff_input_list = (
                    (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
                )
                diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
                diff_input_list = [
                    i.orig_weight if isinstance(i, ExpandedWeight) else i
                    for i in diff_input_list
                ]
                result.sum().backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported
    )
    def test_expanded_weight_forward(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0].clone(),
                    args=(sample_input.input.clone(),),
                    kwargs=sample_input.kwargs,
                )
                if (
                    "cuda" in device
                    and "max_norm" in sample_input.kwargs
                    and "padding_idx" in sample_input.kwargs
                ):
                    self.skipTest(
                        "embedding is non-deterministic in this case, see issue #74679"
                    )
            batch_size = (
                sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
            )
            for loss_reduction in ["sum", "mean"]:
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size, loss_reduction
                )
                expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                normal_result = run_op(
                    op, sample_input.input, *sample_input.args, **sample_input.kwargs
                )
                self.assertEqual(expanded_weight_result, normal_result)

    def test_expanded_weight_error(self, device):
        batch_size = 3
        sample_input = make_tensor(
            (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True
        )
        sample_weight = make_tensor(
            (4), dtype=torch.float32, device=device, requires_grad=True
        )
        with self.assertRaisesRegex(
            RuntimeError, r"Expanded Weights encountered but cannot handle function"
        ):
            torch.add(
                sample_input,
                ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"),
            )

    def _test_embedding_model(self, model, num_embedding, device):
        batch_size = 32
        input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
        return self._test_model(
            partial(model, num_embedding=num_embedding), batch_size, input, device
        )

    def _test_conv_model(
        self,
        model,
        input_size,
        num_dim,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        batch_size = 32
        input_ending = [input_size] * num_dim
        input = torch.randn([batch_size, 3] + input_ending, device=device)
        return self._test_model(
            partial(model, num_dim=num_dim),
            batch_size,
            input,
            device,
            loss_reduction,
            atol,
            rtol,
        )

    def _test_model(
        self,
        model,
        batch_size,
        input,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        model = model(10).to(device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        criterion = CrossEntropyLoss(reduction=loss_reduction)
        result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input)
        loss = criterion(result, targets)
        loss.backward()
        result = []
        for weight in model.parameters():
            result.append(weight.grad_sample)
            del weight.grad_sample

        expected = []
        for i in range(batch_size):
            loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
            expected.append(
                torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))
            )

        expected = [torch.stack(grad) for grad in zip(*expected)]
        for res, exp in zip(result, expected):
            self.assertEqual(res, exp, atol=atol, rtol=rtol)

    def _compute_tolerances(self, device):
        is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(
            0
        ) == (8, 6)
        return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5)

    @tf32_off()
    def test_cnn_model_sum(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol)

    @tf32_off()
    def test_cnn_model_mean(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_instance_norm_model(self, num_dim, device):
        def instance_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            norm_layer = (
                nn.InstanceNorm1d
                if num_dim == 1
                else nn.InstanceNorm2d
                if num_dim == 2
                else nn.InstanceNorm3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                norm_layer(32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_group_norm_model(self, num_dim, device):
        def group_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(8, 32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_layer_norm_model(self, num_dim, device):
        def layer_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            normalized_shape = [7] * num_dim
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.LayerNorm(normalized_shape, elementwise_affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    def test_embedding_model(self, device):
        def embedding_model(num_classes, num_embedding):
            return nn.Sequential(
                nn.Embedding(num_embedding, 15),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(375, num_classes, bias=True),
            )

        return self._test_embedding_model(embedding_model, 16, device)

    def test_group_norm_error(self, device):
        # group norm has to call native_group_norm. This checks that it hits the same errors
        # that normal group norm would

        N = 3
        C = 5
        inp = torch.randn(N, C)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected number of channels in input to be divisible"
        ):
            F.group_norm(inp, 2)  # 5 is not divisible by 2


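# Module-level checks for call_for_per_sample_grads, including RNNs, packed
# sequences, and the API's error handling. The usage pattern these tests exercise
# is roughly the following (a minimal sketch using the same shapes as
# test_per_sample_api_failing below):
#
#   module = nn.Linear(10, 10)
#   out = call_for_per_sample_grads(module, loss_reduction="sum")(torch.randn(64, 10))
#   out.sum().backward()
#   per_sample = [p.grad_sample for p in module.parameters()]  # one grad per sample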
class TestExpandedWeightModule(TestCase):
    def _do_test(
        self,
        module,
        input,
        args=None,
        kwargs=None,
        batch_first=True,
        atol=None,
        rtol=None,
    ):
        args = args or ()
        kwargs = kwargs or {}

        batch_dim = 0 if batch_first else 1
        batch_size = input.shape[batch_dim]
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()

        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module,
                batch_size=batch_size,
                loss_reduction="sum",
                batch_first=batch_first,
            )(input, *args, **kwargs).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop
            expected_res = torch.tensor(
                0.0, device=input.device, dtype=actual_res.dtype
            )
            expected_grads = []
            for i in range(batch_size):
                input_slice = input.narrow(batch_dim, i, 1)
                input_slice = input_slice.squeeze(batch_dim)

                # h's batch dim is always the first dim. Must be contiguous for CUDA
                sliced_args = tree_map_only(
                    torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args
                )
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(
                    input_slice.unsqueeze(batch_dim).contiguous(),
                    *sliced_args,
                    **kwargs,
                ).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
                expected_res += res
            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
            if not batch_first:
                expected_grads[-1] = expected_grads[-1].transpose(0, 1)
        self.assertEqual(actual_res, expected_res, atol=atol, rtol=rtol)
        [
            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
            for (actual, expected) in zip(actual_grads, expected_grads)
        ]

    def _do_test_multi_input(self, module, input):
        class TestModule(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, input):
                return self.module(input) + self.module(input)

        batch_size = input.shape[0]
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager, calling .backward() twice
            test_module = TestModule(module)
            actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(
                input
            ).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop, running over the input twice
            expected_grads = []
            for i in range(batch_size):
                input_slice = input[i]
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(input_slice.unsqueeze(0)).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
        expected_grads = tuple(
            expected_grad
            for expected_grad in expected_grads
            if expected_grad is not None
        )
        assert [
            self.assertEqual(actual, 2 * expected)
            for (actual, expected) in zip(actual_grads, expected_grads)
        ]

    def _do_test_rnn_packed_sequence(
        self, module, input, args=None, kwargs=None, atol=None, rtol=None
    ):
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        batch_size = max(tuple(input.batch_sizes)).item()

        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module, batch_size=batch_size, loss_reduction="sum"
            )(input, *args, **kwargs).data.sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                self.assertEqual(param.grad_sample.shape[0], batch_size)
                actual_grads.append(param.grad_sample)
                del param.grad_sample

            input.data.grad = torch.zeros_like(input.data)

            # compute the per sample grads with a for loop
            expected_res = torch.zeros_like(actual_res)
            expected_grads = []
            padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(
                input, batch_first=True
            )
            for i in range(len(seq_sizes)):
                input_slice = padded_input[i].narrow(0, 0, seq_sizes[i])
                diff_params = module.parameters()
                batch_dim = 0 if module.m.batch_first else 1
                res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum()
                expected_res += res
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)

            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
            self.assertEqual(actual_res, expected_res, atol=atol, rtol=rtol)
            [
                self.assertEqual(actual, expected, atol=atol, rtol=rtol)
                for (actual, expected) in zip(actual_grads, expected_grads)
            ]

    @modules(
        filter(
            lambda m_info: m_info.module_cls
            in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
    )
    @tf32_off()
    def test_module(self, device, dtype, module_info, training):
        class RNNWrapper(torch.nn.Module):
            def __init__(self, m_cons, args, kwargs):
                super().__init__()
                self.m = m_cons(*args, **kwargs)

            def forward(self, *inps):
                ret = self.m(*inps)
                assert isinstance(ret, tuple)
                return ret[0]

        def batch_hidden(h):
            new_h_shape = [1] * (len(h.shape) + 1)
            new_h_shape[1] = 2
            return h.unsqueeze(1).repeat(new_h_shape)

        module_cls = module_info.module_cls
        atol, rtol = (1e-3, 1e-4) if dtype == torch.float32 else (None, None)
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
            with_packed_sequence=True,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = RNNWrapper(module_cls, args, kwargs)
            batch_first = m.m.batch_first
            m.to(device).to(dtype)

            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )

            # if the RNN tests use unbatched inputs--batch the inputs
            input = args[0]
            if isinstance(input, torch.Tensor) and input.dim() == 2:
                input = input.detach()
                new_input_shape = [1] * (len(input.shape) + 1)
                if batch_first:
                    new_input_shape[0] = 2
                    input = input.repeat(new_input_shape)
                else:
                    new_input_shape[1] = 2
                    input = input.unsqueeze(1).repeat(new_input_shape)

                h = args[1] if len(args) > 1 else None
                if h is not None:
                    h = (
                        batch_hidden(h)
                        if isinstance(h, torch.Tensor)
                        else tuple(batch_hidden(hx) for hx in h)
                    )
                    args = list(args)
                    args[1] = h

            if isinstance(input, torch.nn.utils.rnn.PackedSequence):
                self._do_test_rnn_packed_sequence(
                    m, input, args[1:], kwargs, atol=atol, rtol=rtol
                )
            else:
                self._do_test(
                    m,
                    input,
                    args[1:],
                    kwargs,
                    batch_first=batch_first,
                    atol=atol,
                    rtol=rtol,
                )

    def test_per_sample_api_failing(self):
        module = nn.Linear(10, 10)
        input = torch.randn(64, 10)
        with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
            call_for_per_sample_grads("fail")(input)
        with self.assertRaisesRegex(
            RuntimeError, r"Batch size passed must be None or an integer"
        ):
            call_for_per_sample_grads(module, batch_size=6.4)(input)
        with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
            call_for_per_sample_grads(module, batch_size=-64)(input)
        with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
            loss = call_for_per_sample_grads(module)(input).sum()
            loss.backward()  # populate grad_sample fields
            call_for_per_sample_grads(module)(input)

        module = nn.Linear(10, 10)  # reset to not have grad_sample fields
        with self.assertRaisesRegex(
            RuntimeError, r"Expected loss_reduction argument to be sum or mean"
        ):
            call_for_per_sample_grads(module, loss_reduction="")(input)

    def test_per_sample_api_compute_batch_size(self):
        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1) + self.linear(input2)

        module = CustomModule()
        input1 = torch.randn(4, 5)
        input2 = torch.randn(5, 5)

        with self.assertRaisesRegex(
            RuntimeError,
            "found at least one input with batch size 4 and one with batch size 5",
        ):
            call_for_per_sample_grads(module)(input1, input2)

        input2 = torch.randn(4, 5)
        call_for_per_sample_grads(module)(input1, input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1, input2=input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1=input1, input2=input2)

    def test_per_sample_api_compute_batch_size_not_pytreeable(self):
        @dataclass
        class NonPytreeableTuple:
            elem1: torch.Tensor
            elem2: torch.Tensor

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1.elem1) + self.linear(input1.elem2)

        input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5))
        model = CustomModule()
        with self.assertRaisesRegex(
            RuntimeError,
            "ExpandedWeights cannot compute the batch size from the inputs",
        ):
            call_for_per_sample_grads(model)(input, "")

        # would prefer for it to error because input is not pytree-able but that's hard to detect
        with self.assertRaisesRegex(
            RuntimeError, "Expected ExpandedWeights to have batch size matching input"
        ):
            call_for_per_sample_grads(model)(input, torch.randn(5))

        model = CustomModule()  # TODO: functional call bug, sam will fix
        call_for_per_sample_grads(model)(input, torch.randn(4, 5))
        model = CustomModule()
        call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5))


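# TestBase subclass used below to generate per-module cases from the legacy
# common_nn test lists; each generated case builds the module plus a sample input
# and runs TestExpandedWeightModule._do_test / _do_test_multi_input on them.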
class ContextManagerTests(TestBase):
    def __init__(self, *args, **kwargs):
        self.test_cpu = kwargs.get("test_cpu", True)
        self.test_cuda = kwargs.get("test_cuda", True)
        super().__init__(*args, **kwargs)

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", False)

    def test_context_manager(self, test_case, device):
        kwargs = {"device": device, "dtype": torch.double}
        module = self.constructor(*self.constructor_args).to(**kwargs)
        if "Embedding" in self.get_name():
            kwargs["dtype"] = torch.long
        input = self._get_input().to(**kwargs)
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest(
                "Can't get per sample gradients when no batch dim or batch dim is 0"
            )
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest(
                "Can't get per sample gradients for input of rank 1"
            )
        test_case._do_test(module, input)

    def test_context_manager_multiple_inputs(self, test_case, device):
        module = self.constructor(*self.constructor_args).to(device)
        input = self._get_input()
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest(
                "Can't get per sample gradients when no batch dim or batch dim is 0"
            )
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest(
                "Can't get per sample gradients for input of rank 1"
            )
        test_case._do_test_multi_input(module, input)


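# Keep only the legacy nn test entries whose module types have ExpandedWeights support.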
def filter_supported_tests(t):
    supported_modules = [
        "Linear",
        "Conv1d",
        "Conv2d",
        "Conv3d",
        "Embedding",
        "LayerNorm",
        "GroupNorm",
        "InstanceNorm",
    ]
    if "module_name" in t and t["module_name"] in supported_modules:
        return True


# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
# These currently use the legacy nn tests
supported_tests = [
    t for t in module_tests + get_new_module_tests() if filter_supported_tests(t)
]
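# Dynamically attach the generated cases to TestExpandedWeightModule: a CPU test and
# a CPU multiple-input test per entry, plus a double-precision CUDA variant when
# CUDA is available.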
for test_param in supported_tests:
    if "constructor" not in test_param:
        name = test_param.pop("module_name")
        test_param["constructor"] = getattr(nn, name)
    decorator = test_param.pop("decorator", lambda test: test)
    test = ContextManagerTests(**test_param)
    test_name = test.get_name()
    if hasattr(TestExpandedWeightModule, test_name):
        raise RuntimeError("Found two tests with the same name: " + test_name)
    test_name_multi_input = test.get_name() + "_multiple_inputs"
    if hasattr(TestExpandedWeightModule, test_name_multi_input):
        raise RuntimeError("Found two tests with the same name: " + test_name)
    if test.test_cpu:
        setattr(
            TestExpandedWeightModule,
            test_name,
            decorator(lambda self, test=test: test.test_context_manager(self, "cpu")),
        )
        setattr(
            TestExpandedWeightModule,
            test_name_multi_input,
            decorator(
                lambda self, test=test: test.test_context_manager_multiple_inputs(
                    self, "cpu"
                )
            ),
        )
    if TEST_CUDA and test.test_cuda:
        # since this checks derivatives, only use double for precision
        setattr(
            TestExpandedWeightModule,
            test_name + "_cuda_double",
            decorator(lambda self, test=test: test.test_context_manager(self, "cuda")),
        )

# ------------- HELPER FUNCTIONS -----------------


def run_op(op, input, *args, **kwargs):
    r"""
    OpInfo for Embedding switches the input and weight so autograd tests will only check the derivative
    of the weight, not the input, which can't be differentiable since its dtype is int. Calls op,
    using the special ordering that Embedding's OpInfo expects for that case.
    """
    if op.name == "nn.functional.embedding":
        return op(args[0], input, **kwargs)
    else:
        return op(input, *args, **kwargs)


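# Mirror a SampleInput with its differentiable tensors wrapped in ExpandedWeight
# (using the given batch size and loss reduction) and everything else cloned.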
def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
    def expanded_weight_or_clone(arg):
        if is_diff_tensor(arg):
            return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
        return clone_if_tensor(arg)

    ew_input = clone_if_tensor(sample_input.input)
    ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
    ew_kwargs = {
        name: expanded_weight_or_clone(arg)
        for (name, arg) in sample_input.kwargs.items()
    }
    return ew_input, ew_args, ew_kwargs


def supported_inputs(op, sample_inputs, supported_inputs=True):
    r"""
    ExpandedWeights currently does not support some use cases when there's no batch dimension or
    operations that would cause inter-batch operations. Removes all of the cases it cannot deal with
    """

    def filter_fn(input):
        convolutions = [
            "nn.functional.conv1d",
            "nn.functional.conv2d",
            "nn.functional.conv3d",
        ]
        batched_input_size = dict(zip(convolutions, [3, 4, 5]))
        if op.name == "nn.functional.linear":
            is_supported_input = (
                input.input.dim() > 1
            )  # input of rank 1 means no batch dim
        elif op.name == "nn.functional.layer_norm":
            normalized_shape = input.args[0]
            is_supported_input = (
                input.input.shape != normalized_shape
            )  # would cause inter-batch operations
        elif op.name in convolutions:
            # currently can't deal with padding computation on Python level
            is_supported_input = input.input.dim() == batched_input_size[op.name]
        elif op.name == "nn.functional.embedding":
            idx = input.args[0]
            is_supported_input = len(idx.shape) > 1  # there's no batch size
        else:
            is_supported_input = True
        is_supported_input = (
            is_supported_input and input.input.shape[0] > 0
        )  # 0 is not a valid batch size
        return is_supported_input if supported_inputs else not is_supported_input

    return [input for input in sample_inputs if filter_fn(input)]


def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
    # get per sample grads by getting derivative for each input in a for loop
    per_sample_grad = []
    for i in range(batch_size):
        per_sample_input = input[i]
        result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
        diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
        diff_input_list = [
            i
            for i in diff_input_list
            if isinstance(i, torch.Tensor) and i.requires_grad
        ]
        per_sample_grad.append(
            torch.autograd.grad(
                result, diff_input_list, torch.ones_like(result), allow_unused=True
            )
        )
    if len(per_sample_grad) == batch_size:
        per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad))
    return per_sample_grad


def is_diff_tensor(t):
    return isinstance(t, ExpandedWeight) or (
        isinstance(t, torch.Tensor) and t.requires_grad
    )


def clone_if_tensor(t):
    if isinstance(t, torch.Tensor):
        res = torch.clone(t).detach()
        res.requires_grad_(t.requires_grad)
        return res
    else:
        return t


instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
instantiate_device_type_tests(TestExpandedWeightModule, globals())
if __name__ == "__main__":
    run_tests()