# mypy: ignore-errors

from abc import abstractmethod
import tempfile
import unittest

from copy import deepcopy
from functools import reduce, partial
from itertools import product
from operator import mul

import torch
import torch.cuda
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import _reduction as _Reduction
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
    gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo
from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
from torch.autograd import Variable
from torch.types import _TensorOrTensors
import torch.backends.cudnn

from typing import Dict, Callable, Tuple, List, Sequence, Union, Any

TemporaryFile = tempfile.TemporaryFile
PRECISION = 1e-5


def get_reduction(m):
    result = getattr(m, 'reduction', None)
    if result is None:
        result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
    assert result is not None
    return result


def get_weight(m):
    result = getattr(m, 'weight', None)
    if result is not None:
        return result
    return getattr(m, 'weights', None)

# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
#
# The way to check API parity is to add parity tests for the NN module / functional of interest.
# Here are the detailed steps:
#
# For NN module:
# 1. Make sure you already have a test dict with the module configuration you want to test.
# 2. Add `cpp_constructor_args` entry to the test dict, with its value exactly matching
#    the Python module constructor arguments. For example, if in the test dict we pass
#    `(10, 8)` to `torch.nn.Linear` constructor, then we should pass
#    `torch::nn::LinearOptions(10, 8)` as the corresponding C++ constructor argument to
#    `torch::nn::Linear`.
# 3. If in the process of performing the above step you referenced any variables
#    in the `cpp_constructor_args` entry, you must add `cpp_var_map` entry
#    to the test dict to make sure that those variables are populated with the right Python values.
#    For example, if the Python constructor call is
#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
#    the corresponding C++ constructor argument is
#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
#    and the `cpp_var_map` entry must be
#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
#    used in the C++ constructor argument with the Python tensor value `random_samples`.
#
# For NN functional:
# 1. Make sure you already have a test dict with the functional configuration you want to test.
# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
#    then you must add `cpp_options_args` entry to the test dict, with its value exactly matching
#    the Python functional optional arguments. For example, if the test dict's `constructor` entry is
#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
#    then the `cpp_options_args` entry should be
#    "F::InterpolateFuncOptions().size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)".
# 3. Otherwise, if the test dict's `constructor` entry looks like
#    `wrap_functional(lambda i: F.some_functional_name(...))`,
#    then you must add `cpp_function_call` entry to the test dict, with its value exactly matching
#    the Python functional function call.
#    For example, if the test dict's `constructor` entry is
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
#    then the `cpp_function_call` entry should be
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
# 4. If in the process of performing the above two steps you referenced any variables
#    in the `cpp_options_args` or `cpp_function_call` entry, you must
#    add `cpp_var_map` entry to the test dict to make sure that those variables
#    are populated with the right Python values. For example, if the test dict's `constructor` entry is
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
#    then the `cpp_function_call` entry should be
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
#    Notice that there are two variables `i` and `t` that need to have their values provided,
#    and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`.
#    (Note that for `i`, since we want it to take the Python input value, we pass '_get_input()' string as value
#    and the C++ parity test mechanism will populate `i` with the Python input value correctly.)
#
# There are also a few optional flags in the test dict to control the C++ parity test behavior:
#
# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True.
# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True.

module_tests = [
    dict(
        module_name='Linear',
        constructor_args=(10, 8),
        cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
        input_size=(4, 10),
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
        with_tf32=True,
        tf32_precision=0.005,
        default_dtype=torch.double,
    ),
    dict(
        module_name='Linear',
        constructor_args=(10, 8, False),
        cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)',
        input_size=(4, 10),
        desc='no_bias',
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
        with_tf32=True,
        tf32_precision=0.005,
        default_dtype=torch.double,
    ),
    dict(
        module_name='RReLU',
        input_size=(1, 2, 2),
        test_cuda=False,
        default_dtype=torch.double,
    ),
    dict(
        module_name='RReLU',
        constructor_args=(0.1, 0.9),
        cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
        input_size=(4, 4, 5),
        desc='with_up_down',
        test_cuda=False,
        default_dtype=torch.double,
    ),
    dict(
        module_name='Flatten',
        input_size=(2, 3, 4, 5),
        reference_fn=lambda i, *_: torch.flatten(i, 1),
        default_dtype=torch.double,
    ),
    # TODO: reference function
    dict(
        module_name='CrossMapLRN2d',
        constructor_args=(5, 5e-3, 1e-3, 2),
        cpp_constructor_args='torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)',
        input_size=(2, 3, 6, 6),
        check_gradgrad=False,
        # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched"
        check_batched_grad=False,
        default_dtype=torch.double,
    ),
]

# Generates rand tensor with non-equal values. This ensures that duplicate
# values won't be causing test failure for modules like MaxPooling.
# size should be small, otherwise randperm fails / long overflows.
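
# The block below is an illustrative sketch tying the NOTE
# [How to check NN module / functional API parity between Python and C++ frontends]
# to a concrete entry, using the FractionalMaxPool2d example mentioned there:
# `cpp_constructor_args` mirrors the Python constructor call exactly (step 2), and
# `cpp_var_map` supplies a Python value for every variable referenced in the C++
# string (step 3). The names `_example_random_samples` / `_example_parity_entry`
# and the chosen sizes are hypothetical; this dict is illustrative only and is not
# part of `module_tests`.
_example_random_samples = torch.rand(1, 3, 2, dtype=torch.double)
_example_parity_entry = dict(
    module_name='FractionalMaxPool2d',
    # Python module under test.
    constructor=lambda: nn.FractionalMaxPool2d(
        2, output_ratio=0.5, _random_samples=_example_random_samples),
    # C++ constructor options mirroring the Python arguments.
    cpp_constructor_args='torch::nn::FractionalMaxPool2dOptions(2)'
                         '.output_ratio(0.5)._random_samples(random_samples)',
    # Populates the C++ variable `random_samples` with the Python tensor above.
    cpp_var_map={'random_samples': _example_random_samples},
    input_size=(1, 3, 7, 7),
)
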
def _rand_tensor_non_equal(*size): total = reduce(mul, size, 1) return torch.randperm(total).view(*size).double() def wrap_functional(fn, **kwargs): class FunctionalModule(nn.Module): def forward(self, *args): return fn(*args, **kwargs) return FunctionalModule def poissonnllloss_no_reduce_test(): t = torch.randn(10, 10) return dict( fullname='PoissonNLLLoss_no_reduce', constructor=wrap_functional( lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')), cpp_function_call='F::poisson_nll_loss(' 'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))', input_fn=lambda: torch.rand(10, 10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: i.exp() - t.mul(i), pickle=False, default_dtype=torch.double) def bceloss_no_reduce_test(): t = Variable(torch.randn(15, 10).gt(0).to(torch.double)) return dict( fullname='BCELoss_no_reduce', constructor=wrap_functional( lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), cpp_function_call='F::binary_cross_entropy(' 'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))', input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()), pickle=False, precision=7e-4, default_dtype=torch.double) def bceloss_no_reduce_scalar_test(): t = torch.randn(()).gt(0).to(torch.double) return dict( fullname='BCELoss_no_reduce_scalar', constructor=wrap_functional( lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), cpp_function_call='F::binary_cross_entropy(' 'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))', input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()), pickle=False, default_dtype=torch.double) def bceloss_weights_no_reduce_test(): t = Variable(torch.randn(15, 10, dtype=torch.double).gt(0).to(torch.double)) weights = torch.rand(10, dtype=torch.double) return dict( fullname='BCELoss_weights_no_reduce', constructor=wrap_functional( lambda i: F.binary_cross_entropy(i, t.type_as(i), weight=weights.type_as(i), reduction='none')), cpp_function_call='F::binary_cross_entropy(' 'i, t.to(i.options()), ' 'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))', input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, pickle=False, precision=3e-4, default_dtype=torch.double, ) def bceloss_weights_no_reduce_scalar_test(): t = torch.randn(()).gt(0).to(torch.double) weights = torch.rand((), dtype=torch.double) return dict( fullname='BCELoss_weights_no_reduce_scalar', constructor=wrap_functional( lambda i: F.binary_cross_entropy(i, t.type_as(i), weight=weights.type_as(i), reduction='none')), cpp_function_call='''F::binary_cross_entropy( i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''', cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, pickle=False, default_dtype=torch.double, ) def bce_with_logistic_legacy_enum_test(): t = Variable(torch.randn(15, 10).gt(0).to(torch.double)) sigmoid = nn.Sigmoid() return 
dict( fullname='BCEWithLogitsLoss_legacy_enum', constructor=wrap_functional( lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)), cpp_function_call='''F::binary_cross_entropy_with_logits( i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), check_gradgrad=False, pickle=False, default_dtype=torch.double, ) def bce_with_logistic_no_reduce_test(): t = Variable(torch.randn(15, 10).gt(0).to(torch.double)) sigmoid = nn.Sigmoid() return dict( fullname='BCEWithLogitsLoss_no_reduce', constructor=wrap_functional( lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')), cpp_function_call='''F::binary_cross_entropy_with_logits( i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), check_gradgrad=False, pickle=False, default_dtype=torch.double, ) def bce_with_logistic_no_reduce_scalar_test(): t = torch.randn(()).gt(0).to(torch.double) sigmoid = nn.Sigmoid() return dict( fullname='BCEWithLogitsLoss_no_reduce_scalar', constructor=wrap_functional( lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')), cpp_function_call='''F::binary_cross_entropy_with_logits( i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()), check_gradgrad=False, pickle=False, default_dtype=torch.double, ) def kldivloss_with_target_no_reduce_test(): t = torch.rand(10, 10, dtype=torch.double) return dict( fullname='KLDivLoss_with_target_no_reduce', constructor=wrap_functional( lambda i: F.kl_div(i, t.type_as(i), reduction='none')), cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', input_fn=lambda: torch.rand(10, 10).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def kldivloss_no_reduce_test(): t = torch.rand(10, 10, dtype=torch.double) return dict( fullname='KLDivLoss_no_reduce', constructor=wrap_functional( lambda i: F.kl_div(i, t.type_as(i), reduction='none')), cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', input_fn=lambda: torch.rand(10, 10).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False, default_dtype=torch.double, ) def kldivloss_no_reduce_scalar_test(): t = torch.rand((), dtype=torch.double) return dict( fullname='KLDivLoss_no_reduce_scalar', constructor=wrap_functional( lambda i: F.kl_div(i, t.type_as(i), reduction='none')), cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))', input_fn=lambda: torch.rand(()).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['KLDivLoss'](i, 
t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def kldivloss_with_log_target_no_reduce_test(): t = torch.rand(10, 10, dtype=torch.double).log() return dict( fullname='KLDivLoss_with_log_target_no_reduce', constructor=wrap_functional( lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', input_fn=lambda: torch.rand(10, 10).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def kldivloss_no_reduce_log_target_test(): t = torch.rand(10, 10, dtype=torch.double).log() return dict( fullname='KLDivLoss_no_reduce_log_target', constructor=wrap_functional( lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', input_fn=lambda: torch.rand(10, 10).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False, default_dtype=torch.double, ) def kldivloss_no_reduce_scalar_log_target_test(): t = torch.rand((), dtype=torch.double).log() return dict( fullname='KLDivLoss_no_reduce_scalar_log_target', constructor=wrap_functional( lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)), cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))', input_fn=lambda: torch.rand(()).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def l1loss_no_reduce_test(): t = torch.randn(2, 3, 4, dtype=torch.double) return dict( fullname='L1Loss_no_reduce', constructor=wrap_functional( lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', input_fn=lambda: torch.randn(2, 3, 4), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def l1loss_no_reduce_complex_test(): t = torch.randn(2, 3, 4, dtype=torch.cdouble) return dict( fullname='L1Loss_no_reduce_complex', constructor=wrap_functional( lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), supports_forward_ad=True, pickle=False) def l1loss_no_reduce_scalar_test(): t = torch.randn((), dtype=torch.double) return dict( fullname='L1Loss_no_reduce_scalar', constructor=wrap_functional( lambda i: F.l1_loss(i, t.type_as(i), reduction='none')), cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))', input_fn=lambda: torch.randn(()), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: (i - t.type_as(i)).abs(), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def 
mseloss_no_reduce_test(): input_size = (2, 3, 4, 5) target = torch.randn(*input_size, dtype=torch.double) return dict( fullname='MSELoss_no_reduce', constructor=wrap_functional( lambda i: F.mse_loss(i, target.type_as(i), reduction='none')), cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))', input_size=input_size, cpp_var_map={'i': '_get_input()', 'target': target}, reference_fn=lambda i, *_: (i - target).pow(2), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def mseloss_no_reduce_scalar_test(): input_size = () target = torch.randn(input_size, dtype=torch.double) return dict( fullname='MSELoss_no_reduce_scalar', constructor=wrap_functional( lambda i: F.mse_loss(i, target.type_as(i), reduction='none')), cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))', input_size=input_size, cpp_var_map={'i': '_get_input()', 'target': target}, reference_fn=lambda i, *_: (i - target).pow(2), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def nllloss_no_reduce_test(): t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) kwargs = {'reduction': 'none'} return dict( fullname='NLLLoss_no_reduce', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.rand(15, 10).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs), pickle=False, default_dtype=torch.double) def nllloss_no_reduce_ignore_index_test(): t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'} return dict( fullname='NLLLoss_no_reduce_ignore_index', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), reduction=str(kwargs['reduction']))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''', input_fn=lambda: torch.rand(15, 10).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs), pickle=False, default_dtype=torch.double) def nllloss_no_reduce_weights_test(): t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) weight = torch.rand(10) def kwargs(i): return {'weight': weight.type_as(i), 'reduction': 'none'} return dict( fullname='NLLLoss_no_reduce_weights', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''', input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, reference_fn=lambda i, *_: loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)), pickle=False, default_dtype=torch.double) def nllloss_no_reduce_weights_ignore_index_test(): t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) weight = torch.rand(10) def kwargs(i): return {'weight': weight.type_as(i), 'reduction': 'none', 'ignore_index': 2} return dict( fullname='NLLLoss_no_reduce_weights_ignore_index', constructor=wrap_functional( lambda i: 
F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''', input_fn=lambda: torch.rand(15, 10).add(1e-2).log(), cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, reference_fn=lambda i, *_: loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)), pickle=False, default_dtype=torch.double) def nllloss_no_reduce_weights_ignore_index_neg_test(): t = Variable(torch.empty(15).uniform_().mul(10).floor().long()) weight = torch.rand(10) def kwargs(i): return {'weight': weight.type_as(i), 'reduction': 'none', 'ignore_index': -1} return dict( fullname='NLLLoss_no_reduce_weights_ignore_index_neg', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''', input=torch.rand(15, 10, dtype=torch.double).add(1e-2).log(), cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, reference_fn=lambda i, *_: loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)), pickle=False, default_dtype=torch.double) def nllloss2d_no_reduce_test(): t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) kwargs = {'reduction': 'none'} return dict( fullname='NLLLoss2d_no_reduce', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), pickle=False, default_dtype=torch.double) def nllloss2d_no_reduce_ignore_index_test(): t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'} return dict( fullname='NLLLoss2d_no_reduce_ignore_index', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), reduction=str(kwargs['reduction']))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), pickle=False, default_dtype=torch.double) def nllloss2d_no_reduce_weights_test(): t = Variable(torch.rand(2, 5, 5).mul(3).floor().long()) weight = torch.rand(3) def kwargs(i): return {'weight': weight.type_as(i), 'reduction': 'none'} return dict( fullname='NLLLoss2d_no_reduce_weights', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5).log(), cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, reference_fn=lambda i, *_: loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)), pickle=False, default_dtype=torch.double) def nlllossNd_no_reduce_test(): t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) kwargs = {'reduction': 'none'} 
return dict( fullname='NLLLossNd_no_reduce', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), pickle=False, default_dtype=torch.double) def nlllossNd_no_reduce_ignore_index_test(): t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'} return dict( fullname='NLLLossNd_no_reduce_ignore_index', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']), reduction=str(kwargs['reduction']))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs), pickle=False, default_dtype=torch.double) def nlllossNd_no_reduce_weights_test(): t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long()) weight = torch.rand(3) def kwargs(i): return {'weight': weight.type_as(i), 'reduction': 'none'} return dict( fullname='NLLLossNd_no_reduce_weights', constructor=wrap_functional( lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))), cpp_function_call='''F::nll_loss( i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''', input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(), cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight}, reference_fn=lambda i, *_: loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)), pickle=False, default_dtype=torch.double) def smoothl1loss_no_reduce_test(): t = torch.randn(2, 3, 4, dtype=torch.double) return dict( fullname='SmoothL1Loss_no_reduce', constructor=wrap_functional( lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')), cpp_function_call='''F::smooth_l1_loss( i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(2, 3, 4), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def smoothl1loss_no_reduce_scalar_test(): t = torch.randn((), dtype=torch.double) return dict( fullname='SmoothL1Loss_no_reduce_scalar', constructor=wrap_functional( lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')), cpp_function_call='''F::smooth_l1_loss( i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(()), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def smoothl1loss_beta_test(): t = torch.randn(2, 3, 4, dtype=torch.double) return dict( fullname='SmoothL1Loss_beta', constructor=wrap_functional( lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)), cpp_function_call='''F::smooth_l1_loss( i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 
0.5)''', input_fn=lambda: torch.randn(2, 3, 4), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def smoothl1loss_zero_beta_test(): t = torch.randn(2, 3, 4, dtype=torch.double) return dict( fullname='SmoothL1Loss_zero_beta', constructor=wrap_functional( lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)), cpp_function_call='''F::smooth_l1_loss( i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''', input_fn=lambda: torch.randn(2, 3, 4), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def huberloss_delta_test(): t = torch.randn(2, 3, 4) return dict( fullname='HuberLoss_delta', constructor=wrap_functional( lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)), cpp_function_call='''F::huber_loss( i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''', input_fn=lambda: torch.randn(2, 3, 4), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def multilabelmarginloss_0d_no_reduce_test(): t = torch.zeros(()).long() return dict( fullname='MultiLabelMarginLoss_0d_no_reduce', constructor=wrap_functional( lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), cpp_function_call='''F::multilabel_margin_loss( i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(()), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False) def multilabelmarginloss_1d_no_reduce_test(): t = Variable(torch.rand(10).mul(10).floor().long()) return dict( fullname='MultiLabelMarginLoss_1d_no_reduce', constructor=wrap_functional( lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), cpp_function_call='''F::multilabel_margin_loss( i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def multilabelmarginloss_index_neg_test(): t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1)) return dict( fullname='MultiLabelMarginLoss_index_neg', constructor=wrap_functional( lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), cpp_function_call='''F::multilabel_margin_loss( i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(5, 10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def 
multilabelmarginloss_no_reduce_test(): t = Variable(torch.rand(5, 10).mul(10).floor().long()) return dict( fullname='MultiLabelMarginLoss_no_reduce', constructor=wrap_functional( lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')), cpp_function_call='''F::multilabel_margin_loss( i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(5, 10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def hingeembeddingloss_no_reduce_test(): t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1)) return dict( fullname='HingeEmbeddingLoss_no_reduce', constructor=wrap_functional( lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')), cpp_function_call='''F::hinge_embedding_loss( i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'), check_sum_reduction=True, pickle=False, default_dtype=torch.double) def hingeembeddingloss_margin_no_reduce_test(): t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1)) return dict( fullname='HingeEmbeddingLoss_margin_no_reduce', constructor=wrap_functional( lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')), cpp_function_call='''F::hinge_embedding_loss( i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''', input_fn=lambda: torch.randn(10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'), check_sum_reduction=True, pickle=False, default_dtype=torch.double) def softmarginloss_no_reduce_test(): t = torch.randn(5, 5, dtype=torch.double) return dict( fullname='SoftMarginLoss_no_reduce', constructor=wrap_functional( lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')), cpp_function_call='''F::soft_margin_loss( i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(5, 5), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'), supports_forward_ad=True, pickle=False, default_dtype=torch.double) def multilabelsoftmarginloss_no_reduce_test(): t = torch.rand(5, 10).mul(2).floor() return dict( fullname='MultiLabelSoftMarginLoss_no_reduce', constructor=wrap_functional( lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')), cpp_function_call='''F::multilabel_soft_margin_loss( i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(5, 10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1), check_gradgrad=False, pickle=False, default_dtype=torch.double) def multilabelsoftmarginloss_weights_no_reduce_test(): t = torch.rand(5, 10).mul(2).floor() weights = torch.rand(10) return dict( fullname='MultiLabelSoftMarginLoss_weights_no_reduce', constructor=wrap_functional( lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), 
weight=weights.type_as(i), reduction='none')), cpp_function_call='''F::multilabel_soft_margin_loss( i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''', input_fn=lambda: torch.randn(5, 10), cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, reference_fn=lambda i, *_: (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def multimarginloss_no_reduce_test(): t = torch.rand(5).mul(8).floor().long() return dict( fullname='MultiMarginLoss_no_reduce', constructor=wrap_functional( lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), cpp_function_call='''F::multi_margin_loss( i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(5, 10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def multimarginloss_1d_no_reduce_test(): t = torch.rand(1).mul(8).floor().long() return dict( fullname='MultiMarginLoss_1d_no_reduce', constructor=wrap_functional( lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), cpp_function_call='''F::multi_margin_loss( i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def multimarginloss_1d_input_0d_target_no_reduce_test(): t = torch.rand(()).mul(8).floor().long() return dict( fullname='multimarginloss_1d_input_0d_target_no_reduce', constructor=wrap_functional( lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')), cpp_function_call='''F::multi_margin_loss( i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''', input_fn=lambda: torch.randn(10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def multimarginloss_p_no_reduce_test(): t = torch.rand(5).mul(8).floor().long() return dict( fullname='MultiMarginLoss_p_no_reduce', constructor=wrap_functional( lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')), cpp_function_call='''F::multi_margin_loss( i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''', input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def multimarginloss_margin_no_reduce_test(): t = torch.rand(5).mul(8).floor().long() return dict( fullname='MultiMarginLoss_margin_no_reduce', constructor=wrap_functional( lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')), cpp_function_call='''F::multi_margin_loss( i, 
t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''', input_fn=lambda: torch.randn(5, 10), cpp_var_map={'i': '_get_input()', 't': t}, reference_fn=lambda i, *_: loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), margin=0.5, reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def multimarginloss_weights_no_reduce_test(): t = torch.rand(5).mul(8).floor().long() weights = torch.rand(10, dtype=torch.double) return dict( fullname='MultiMarginLoss_weights_no_reduce', constructor=wrap_functional( lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i), reduction='none')), cpp_function_call='''F::multi_margin_loss( i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''', input_fn=lambda: torch.randn(5, 10), cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights}, reference_fn=lambda i, *_: loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), weight=weights, reduction='none'), check_sum_reduction=True, check_gradgrad=False, pickle=False, default_dtype=torch.double) def single_batch_reference_fn(input, parameters, module): """Reference function for modules supporting no batch dimensions. The module is passed the input and target in batched form with a single item. The output is squeezed to compare with the no-batch input. """ def unsqueeze_inp(inp): if isinstance(inp, (list, tuple)): return [t.unsqueeze(0) for t in inp] return inp.unsqueeze(0) single_batch_input = unsqueeze_inp(input) single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input with freeze_rng_state(): return module(*single_batch_input).squeeze(0) def get_new_module_tests(): new_module_tests = [ poissonnllloss_no_reduce_test(), bceloss_no_reduce_test(), bceloss_weights_no_reduce_test(), bce_with_logistic_legacy_enum_test(), bce_with_logistic_no_reduce_test(), bceloss_no_reduce_scalar_test(), bceloss_weights_no_reduce_scalar_test(), bce_with_logistic_no_reduce_scalar_test(), kldivloss_with_target_no_reduce_test(), kldivloss_no_reduce_test(), kldivloss_no_reduce_scalar_test(), kldivloss_with_log_target_no_reduce_test(), kldivloss_no_reduce_log_target_test(), kldivloss_no_reduce_scalar_log_target_test(), l1loss_no_reduce_test(), l1loss_no_reduce_complex_test(), l1loss_no_reduce_scalar_test(), mseloss_no_reduce_test(), mseloss_no_reduce_scalar_test(), nllloss_no_reduce_test(), nllloss_no_reduce_ignore_index_test(), nllloss_no_reduce_weights_test(), nllloss_no_reduce_weights_ignore_index_test(), nllloss_no_reduce_weights_ignore_index_neg_test(), nllloss2d_no_reduce_test(), nllloss2d_no_reduce_weights_test(), nllloss2d_no_reduce_ignore_index_test(), nlllossNd_no_reduce_test(), nlllossNd_no_reduce_weights_test(), nlllossNd_no_reduce_ignore_index_test(), smoothl1loss_no_reduce_test(), smoothl1loss_no_reduce_scalar_test(), smoothl1loss_beta_test(), smoothl1loss_zero_beta_test(), huberloss_delta_test(), multilabelmarginloss_0d_no_reduce_test(), multilabelmarginloss_1d_no_reduce_test(), multilabelmarginloss_index_neg_test(), multilabelmarginloss_no_reduce_test(), hingeembeddingloss_no_reduce_test(), hingeembeddingloss_margin_no_reduce_test(), softmarginloss_no_reduce_test(), multilabelsoftmarginloss_no_reduce_test(), multilabelsoftmarginloss_weights_no_reduce_test(), multimarginloss_no_reduce_test(), multimarginloss_1d_no_reduce_test(), 
multimarginloss_1d_input_0d_target_no_reduce_test(), multimarginloss_p_no_reduce_test(), multimarginloss_margin_no_reduce_test(), multimarginloss_weights_no_reduce_test(), dict( module_name='Conv1d', constructor_args=(4, 5, 3), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', input_size=(2, 4, 10), cudnn=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv1d', constructor_args=(4, 5, 3, 2), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)', input_size=(2, 4, 10), cudnn=True, desc='stride', with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv1d', constructor_args=(4, 5, 3, 1, 1), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)', input_size=(2, 4, 10), cudnn=True, desc='pad1', with_tf32=True, tf32_precision=0.01, default_dtype=torch.double, ), dict( module_name='Conv1d', constructor_args=(4, 5, 5, 1, 2), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)', input_size=(2, 4, 10), cudnn=True, desc='pad2', with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv1d', constructor_args=(4, 4, 3, 1, 1), cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)', input_size=(1, 4, 1), cudnn=True, desc='pad1size1', with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv1d', constructor_args=(4, 4, 5, 1, 2), cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)', input_size=(1, 4, 1), cudnn=True, desc='pad2size1', with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv1d', constructor_args=(4, 5, 3), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)', input_size=(0, 4, 10), cudnn=True, desc='zero_batch', with_tf32=True, tf32_precision=0.005, ), dict( fullname='Conv1d_dilated', constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)', input_size=(2, 4, 10), with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv1d_groups', constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2), cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)', input_size=(2, 4, 6), cudnn=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv1d_pad_valid', constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)', input_size=(2, 4, 10), cudnn=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv1d_pad_same', constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)', input_size=(2, 4, 10), cudnn=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv1d_pad_same2', constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)', input_size=(2, 4, 10), cudnn=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv1d_pad_same_dilated', constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2), cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)', input_size=(2, 4, 10), cudnn=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( 
fullname='ConvTranspose1d', constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)), cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)', cudnn=True, input_size=(1, 3, 7), with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='ConvTranspose1d', constructor_args=(3, 4, 3, 2, 1, 1, 1, False), cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) .stride(2).padding(1).output_padding(1).groups(1).bias(false)''', input_size=(1, 3, 6), cudnn=True, desc='no_bias', with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='ConvTranspose1d', constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2), cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3) .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''', input_size=(1, 3, 6), cudnn=True, desc='dilated', with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='ConvTranspose1d_groups', constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2), cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3) .stride(3).padding(1).output_padding(1).groups(2)''', cudnn=True, input_size=(2, 4, 7), with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv2d', constructor_args=(3, 4, (3, 2)), cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', input_size=(2, 3, 7, 5), cudnn=True, check_with_long_tensor=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv2d', constructor_args=(3, 4, (3, 3), (2, 2)), cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})', input_size=(2, 3, 6, 6), cudnn=True, desc='strided', check_with_long_tensor=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv2d', constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)), cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})', input_size=(2, 3, 6, 6), cudnn=True, desc='padding', check_with_long_tensor=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv2d', constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)), cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})', input_size=(2, 3, 8, 8), cudnn=True, desc='dilated', check_with_long_tensor=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv2d', constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False), cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2}) .stride(1).padding(0).dilation(1).groups(1).bias(false)''', input_size=(2, 3, 6, 5), cudnn=True, desc='no_bias', check_with_long_tensor=True, with_tf32=True, tf32_precision=0.015, default_dtype=torch.double, ), dict( module_name='Conv2d', constructor_args=(3, 4, (3, 2)), cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})', input_size=(0, 3, 7, 5), cudnn=True, desc='zero_batch', check_with_long_tensor=True, with_tf32=True, ), dict( fullname='Conv2d_groups', constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', input_size=(2, 4, 6, 5), cudnn=True, check_with_long_tensor=True, with_tf32=True, tf32_precision=0.015, default_dtype=torch.double, ), dict( fullname='Conv2d_groups_thnn', 
constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2), cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)', input_size=(2, 4, 6, 5), check_with_long_tensor=True, with_tf32=True, tf32_precision=0.015, default_dtype=torch.double, ), dict( fullname='Conv2d_pad_valid', constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"), cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)', input_size=(2, 2, 6, 5), cudnn=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv2d_pad_same', constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"), cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)', input_size=(2, 2, 6, 5), cudnn=True, with_tf32=True, tf32_precision=0.01, default_dtype=torch.double, ), dict( fullname='Conv2d_pad_same_dilated', constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2), cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)', input_size=(2, 2, 6, 5), cudnn=True, with_tf32=True, tf32_precision=0.01, default_dtype=torch.double, ), dict( module_name='ConvTranspose2d', constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)), cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) .stride({3, 2}).padding(1).output_padding({1, 1})''', cudnn=True, input_size=(1, 3, 7, 6), check_with_long_tensor=True, with_tf32=True, tf32_precision=0.01, default_dtype=torch.double, ), dict( module_name='ConvTranspose2d', constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)), cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) .stride({2, 3}) .padding(1) .output_padding({1, 1}) .groups(1) .bias(false) .dilation({2, 2})''', input_size=(1, 3, 6, 7), cudnn=True, desc='dilated', check_with_long_tensor=True, with_tf32=True, tf32_precision=0.01, default_dtype=torch.double, ), dict( module_name='ConvTranspose2d', constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False), cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3) .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''', input_size=(1, 3, 6, 7), cudnn=True, desc='no_bias', check_with_long_tensor=True, with_tf32=True, tf32_precision=0.01, default_dtype=torch.double, ), dict( fullname='ConvTranspose2d_groups', constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2), cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)', input_size=(1, 2, 4, 5), cudnn=True, check_with_long_tensor=True, with_tf32=True, tf32_precision=0.01, default_dtype=torch.double, ), dict( fullname='Conv2d_depthwise', constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4), cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)', input_size=(2, 4, 6, 6), with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv2d_depthwise_with_multiplier', constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4), cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)', input_size=(2, 4, 6, 6), with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv2d_depthwise_strided', constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4), cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)', input_size=(2, 4, 6, 6), with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv2d_depthwise_padded', constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4), 
cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)', input_size=(2, 4, 6, 6), with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv2d_depthwise_dilated', constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4), cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)', input_size=(2, 4, 5, 5), with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Conv3d', constructor_args=(2, 3, (2, 3, 2)), cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})', input_size=(1, 2, 4, 5, 4), cudnn=True, check_with_long_tensor=True, with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='Conv3d', constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False), cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) .stride(1).padding(0).dilation(1).groups(1).bias(false)''', input_size=(1, 2, 3, 4, 5), cudnn=True, desc='no_bias', check_with_long_tensor=True, with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='Conv3d', constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False), cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4}) .stride(1).padding(0).dilation(1).groups(1).bias(false)''', input_size=(1, 2, 3, 4, 5), cudnn=True, desc='1x1x1_no_bias', check_with_long_tensor=False, with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='Conv3d', constructor_args=(3, 4, 2, 2), cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)', input_size=(2, 3, 5, 5, 5), cudnn=True, desc='stride', check_with_long_tensor=True, with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='Conv3d', constructor_args=(3, 4, 2, 2, 1), cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)', input_size=(2, 3, 5, 5, 5), cudnn=True, desc='stride_padding', check_with_long_tensor=True, with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='Conv3d', constructor_args=(3, 4, (2, 3, 4)), cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})', input_size=(0, 3, 3, 4, 5), cudnn=True, check_with_long_tensor=True, desc='zero_batch', with_tf32=True, ), dict( fullname='Conv3d_groups', constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2), cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)', input_size=(1, 2, 4, 5, 4), cudnn=True, check_with_long_tensor=True, with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( fullname='Conv3d_dilated', constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2), cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)', input_size=(2, 3, 5, 5, 5), with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( fullname='Conv3d_dilated_strided', constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2), cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)', input_size=(2, 3, 5, 5, 5), with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( fullname='Conv3d_pad_valid', constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"), cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)', input_size=(2, 3, 6, 5, 4), cudnn=True, with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( fullname='Conv3d_pad_same', constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"), 
cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)', input_size=(2, 3, 6, 5, 4), cudnn=True, with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( fullname='Conv3d_pad_same_dilated', constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2), cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)', input_size=(2, 3, 6, 5, 4), cudnn=True, with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='ConvTranspose3d', constructor_args=(2, 3, (2, 3, 2)), cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})', cudnn=True, input_size=(1, 2, 4, 5, 4), with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='ConvTranspose3d', constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)), cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2}) .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''', cudnn=True, input_size=(1, 2, 4, 5, 4), desc='dilated', with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='ReplicationPad3d', constructor_args=((1, 2, 3, 3, 2, 1),), cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', input_size=(2, 3, 2, 2, 2), default_dtype=torch.double, ), dict( module_name='ReplicationPad3d', constructor_args=((1, 2, 3, 3, 2, 1),), cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', input_size=(3, 2, 2, 2), reference_fn=single_batch_reference_fn, desc='no_batch_dim', default_dtype=torch.double, ), dict( module_name='ReplicationPad3d', constructor_args=((1, 2, 3, 3, 2, 1),), cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})', input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True), skip_half=True, desc='complex' ), dict( module_name='Embedding', constructor_args=(4, 3), cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), check_gradgrad=False, default_dtype=torch.double, decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") ), dict( module_name='Embedding', constructor_args=(4, 3), cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), check_gradgrad=False, desc='discontiguous', default_dtype=torch.double, decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971") ), dict( module_name='EmbeddingBag', constructor_args=(4, 3), cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), check_gradgrad=False, desc='mean', default_dtype=torch.double, ), dict( module_name='EmbeddingBag', constructor_args=(4, 3), cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)', input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), check_gradgrad=False, desc='discontiguous', default_dtype=torch.double, ), dict( module_name='EmbeddingBag', constructor_args=(4, 3, None, 2., False, 'sum'), cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''', input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), check_gradgrad=False, desc='sum', default_dtype=torch.double, ), dict( module_name='EmbeddingBag', constructor_args=(4, 3, None, 2., 
False, 'max'), cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''', input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), check_gradgrad=False, desc='max', default_dtype=torch.double, ), dict( fullname='EmbeddingBag_mean_padding_idx', constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1), cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)', input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), check_gradgrad=False, default_dtype=torch.double, ), dict( fullname='EmbeddingBag_sum_padding_idx', constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1), cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''', input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), check_gradgrad=False, default_dtype=torch.double, ), dict( fullname='EmbeddingBag_max_padding_idx', constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1), cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''', input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]), check_gradgrad=False, default_dtype=torch.double, ), dict( fullname='EmbeddingBag_sparse', constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double), cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3) .sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))''', input_fn=lambda: torch.randperm(2).repeat(1, 2), check_gradgrad=False, has_sparse_gradients=True, ), dict( constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True), cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))', input_fn=lambda: torch.randperm(2).repeat(1, 2), fullname='Embedding_sparse', check_gradgrad=False, has_sparse_gradients=True, ), dict( module_name='PixelShuffle', constructor_args=(3,), cpp_constructor_args='torch::nn::PixelShuffleOptions(3)', input_size=(1, 9, 4, 4), default_dtype=torch.double, ), dict( module_name='PixelUnshuffle', constructor_args=(3,), cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)', input_size=(1, 1, 12, 12), default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', input_size=(1, 2, 4), fullname='interpolate_nearest_1d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', input_size=(0, 2, 4), fullname='interpolate_nearest_1d_zero_dim', pickle=False, ), dict( constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)''', input_size=(1, 2, 3), fullname='interpolate_nearest_tuple_1d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() 
.size(std::nullopt).scale_factor(std::vector({4.})).mode(torch::kNearest)''', input_size=(1, 2, 4), fullname='interpolate_nearest_scale_1d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12})) .scale_factor(std::nullopt) .mode(torch::kLinear) .align_corners(false)''', input_size=(1, 2, 4), fullname='interpolate_linear_1d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({4})) .scale_factor(std::nullopt) .mode(torch::kLinear) .align_corners(false)''', input_size=(1, 2, 3), fullname='interpolate_linear_tuple_1d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({4.})) .mode(torch::kLinear) .align_corners(false)''', input_size=(1, 2, 4), fullname='interpolate_linear_scale_1d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12})) .scale_factor(std::nullopt) .mode(torch::kLinear) .align_corners(false)''', input_size=(0, 2, 4), fullname='interpolate_linear_1d_zero_dim', pickle=False, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12})) .scale_factor(std::nullopt) .mode(torch::kLinear) .align_corners(true)''', input_size=(1, 2, 4), fullname='interpolate_linear_1d_align_corners', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({4.})) .mode(torch::kLinear) .align_corners(true)''', input_size=(1, 2, 4), fullname='interpolate_linear_scale_1d_align_corners', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({2, 2})) .scale_factor(std::nullopt) .mode(torch::kNearest)''', input_size=(1, 128, 1, 1), fullname='interpolate_nearest_2d_launch_configs', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12})) .scale_factor(std::nullopt) .mode(torch::kNearest)''', input_size=(1, 2, 4, 4), fullname='interpolate_nearest_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 16})) .scale_factor(std::nullopt) .mode(torch::kNearest)''', input_size=(1, 2, 3, 4), fullname='interpolate_nearest_tuple_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), 
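        # (Descriptive note: the scalar Python scale_factor=4. corresponds to a per-dimension
        #  std::vector in the C++ options below, i.e. {4., 4.} for this 2-d case.)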
cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({4., 4.})) .mode(torch::kNearest)''', input_size=(1, 2, 4, 4), fullname='interpolate_nearest_scale_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12})) .scale_factor(std::nullopt) .mode(torch::kNearest)''', input_size=(0, 2, 4, 4), fullname='interpolate_nearest_2d_zero_dim', pickle=False, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12})) .scale_factor(std::nullopt) .mode(torch::kBilinear) .align_corners(false)''', input_size=(1, 2, 4, 4), fullname='interpolate_bilinear_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12})) .scale_factor(std::nullopt) .mode(torch::kBilinear) .align_corners(false)''', input_size=(0, 2, 4, 4), fullname='interpolate_bilinear_2d_zero_dim', pickle=False, ), dict( constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({4, 6})) .scale_factor(std::nullopt) .mode(torch::kBilinear) .align_corners(false)''', input_size=(1, 2, 2, 3), fullname='interpolate_bilinear_tuple_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({4., 4.})) .mode(torch::kBilinear) .align_corners(false)''', input_size=(1, 2, 4, 4), fullname='interpolate_bilinear_scale_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), mode='bilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({2., 2.})) .mode(torch::kBilinear) .align_corners(false)''', input_size=(1, 2, 4, 4), fullname='interpolate_bilinear_scale_tuple_shared_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), mode='bilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({2., 1.})) .mode(torch::kBilinear) .align_corners(false)''', input_size=(1, 2, 4, 4), fullname='interpolate_bilinear_scale_tuple_skewed_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({4, 6})) .scale_factor(std::nullopt) .mode(torch::kBilinear) .align_corners(true)''', input_size=(1, 2, 4, 4), fullname='interpolate_bilinear_tuple_2d_align_corners', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), mode='bilinear', align_corners=True), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({2., 1.})) .mode(torch::kBilinear) .align_corners(true)''', input_size=(1, 2, 4, 
4), fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12})) .scale_factor(std::nullopt) .mode(torch::kBicubic) .align_corners(false)''', input_size=(1, 2, 4, 4), fullname='interpolate_bicubic_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12})) .scale_factor(std::nullopt) .mode(torch::kBicubic) .align_corners(false)''', input_size=(0, 2, 4, 4), fullname='interpolate_bicubic_2d_zero_dim', pickle=False, ), dict( constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({4, 6})) .scale_factor(std::nullopt) .mode(torch::kBicubic) .align_corners(false)''', input_size=(1, 2, 2, 3), fullname='interpolate_bicubic_tuple_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({4., 4.})) .mode(torch::kBicubic) .align_corners(false)''', input_size=(1, 2, 4, 4), fullname='interpolate_bicubic_scale_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), mode='bicubic', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({2., 2.})) .mode(torch::kBicubic) .align_corners(false)''', input_size=(1, 2, 4, 4), fullname='interpolate_bicubic_scale_tuple_shared_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), mode='bicubic', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({2., 1.})) .mode(torch::kBicubic) .align_corners(false)''', input_size=(1, 2, 4, 4), fullname='interpolate_bicubic_scale_tuple_skewed_2d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({4, 6})) .scale_factor(std::nullopt) .mode(torch::kBicubic) .align_corners(true)''', input_size=(1, 2, 4, 4), fullname='interpolate_bicubic_tuple_2d_align_corners', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), mode='bicubic', align_corners=True), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({2., 1.})) .mode(torch::kBicubic) .align_corners(true)''', input_size=(1, 2, 4, 4), fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12, 12})) .scale_factor(std::nullopt) .mode(torch::kNearest)''', input_size=(1, 2, 4, 4, 4), fullname='interpolate_nearest_3d', pickle=False, default_dtype=torch.double, ), 
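    # Shape sanity check for the nearest-3d entries around here (illustrative only, not part of
    # the test dicts): an integer `size` resizes every spatial dim to that value, and a scalar
    # `scale_factor` is applied to every spatial dim, e.g.
    #   F.interpolate(torch.rand(1, 2, 4, 4, 4), size=12, mode='nearest').shape          == (1, 2, 12, 12, 12)
    #   F.interpolate(torch.rand(1, 2, 4, 4, 4), scale_factor=4., mode='nearest').shape  == (1, 2, 16, 16, 16)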
dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12, 12})) .scale_factor(std::nullopt) .mode(torch::kNearest)''', input_size=(0, 2, 4, 4, 4), fullname='interpolate_nearest_3d_zero_dim', pickle=False, ), dict( constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 16, 16})) .scale_factor(std::nullopt) .mode(torch::kNearest)''', input_size=(1, 2, 3, 4, 4), fullname='interpolate_nearest_tuple_3d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({4., 4., 4.})) .mode(torch::kNearest)''', input_size=(1, 2, 4, 4, 4), fullname='interpolate_nearest_scale_3d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12, 12})) .scale_factor(std::nullopt) .mode(torch::kTrilinear) .align_corners(false)''', input_size=(1, 2, 4, 4, 4), fullname='interpolate_trilinear_3d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({12, 12, 12})) .scale_factor(std::nullopt) .mode(torch::kTrilinear) .align_corners(false)''', input_size=(0, 2, 4, 4, 4), fullname='interpolate_trilinear_3d_zero_dim', pickle=False, ), dict( constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None, mode='trilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({4, 6, 6})) .scale_factor(std::nullopt) .mode(torch::kTrilinear) .align_corners(false)''', input_size=(1, 2, 2, 3, 3), fullname='interpolate_trilinear_tuple_3d', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({3., 3., 3.})) .mode(torch::kTrilinear) .align_corners(false)''', input_size=(1, 2, 3, 4, 5), fullname='interpolate_trilinear_scale_3d', # See https://github.com/pytorch/pytorch/issues/5006 precision=3e-4, pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None, mode='trilinear', align_corners=True), cpp_options_args='''F::InterpolateFuncOptions() .size(std::vector({4, 6, 6})) .scale_factor(std::nullopt) .mode(torch::kTrilinear) .align_corners(true)''', input_size=(1, 2, 2, 3, 3), fullname='interpolate_trilinear_tuple_3d_align_corners', pickle=False, default_dtype=torch.double ), dict( constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True), cpp_options_args='''F::InterpolateFuncOptions() .size(std::nullopt) .scale_factor(std::vector({3., 3., 3.})) .mode(torch::kTrilinear) .align_corners(true)''', input_size=(1, 2, 3, 4, 4), fullname='interpolate_trilinear_scale_3d_align_corners', # See https://github.com/pytorch/pytorch/issues/5006 precision=3e-4, pickle=False, default_dtype=torch.double, ), dict( 
constructor=wrap_functional(F.softmax, dim=-1), cpp_options_args='F::SoftmaxFuncOptions(-1)', input_size=(2, 128), # trigger the last-dim algo in CUDA fullname='softmax_lastdim', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', input_size=(2, 128), fullname='softmax_lastdim_dtype', pickle=False, test_cuda=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.softmax, dim=1), cpp_options_args='F::SoftmaxFuncOptions(1)', input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo fullname='softmax_spatial_special', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.softmax, dim=1), cpp_options_args='F::SoftmaxFuncOptions(1)', input_size=(2, 2, 4, 4), # regular spatial algorithm fullname='softmax_spatial', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64), cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)', input_size=(2, 2, 4, 4), # regular spatial algorithm fullname='softmax_spatial_dtype', pickle=False, test_cuda=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.softmax, dim=0), cpp_options_args='F::SoftmaxFuncOptions(0)', input_size=(2, 3, 4, 5), fullname='softmax_functional_dim0', test_cuda=False, pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.softmax, dim=3), cpp_options_args='F::SoftmaxFuncOptions(3)', input_size=(2, 3, 4, 5), fullname='softmax_functional_dim3', test_cuda=False, pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.softmax, dim=-1), cpp_options_args='F::SoftmaxFuncOptions(-1)', input_size=(), fullname='softmax_functional_scalar', test_cuda=False, pickle=False, ), dict( constructor=wrap_functional(F.log_softmax, dim=-1), cpp_options_args='F::LogSoftmaxFuncOptions(-1)', input_size=(2, 128), # trigger the last-dim algo in CUDA fullname='log_softmax_lastdim', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.log_softmax, dim=1), cpp_options_args='F::LogSoftmaxFuncOptions(1)', input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo fullname='log_softmax_spatial_special', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.log_softmax, dim=1), cpp_options_args='F::LogSoftmaxFuncOptions(1)', input_size=(2, 2, 4, 4), # regular spatial algorithm fullname='log_softmax_spatial', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.log_softmax, dim=0), cpp_options_args='F::LogSoftmaxFuncOptions(0)', input_size=(2, 3, 4, 5), fullname='log_softmax_dim0', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.log_softmax, dim=3), cpp_options_args='F::LogSoftmaxFuncOptions(3)', input_size=(2, 3, 4, 5), fullname='log_softmax_dim3', pickle=False, default_dtype=torch.double, ), dict( constructor=wrap_functional(F.log_softmax, dim=0), cpp_options_args='F::LogSoftmaxFuncOptions(0)', input_size=(), fullname='log_softmax_scalar', pickle=False, ), dict( fullname='Unfold', constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)), cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', input_size=(2, 4, 3, 3), check_gradgrad=False, test_cuda=True, default_dtype=torch.double, ), dict( fullname='Fold', constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), 
(0, 0), (1, 1)), cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', input_size=(2, 16, 4), check_gradgrad=False, test_cuda=True, default_dtype=torch.double, ), dict( fullname='Fold_no_batch_dim_input', constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})', input_size=(16, 4), check_gradgrad=False, ref=single_batch_reference_fn, test_cuda=True, default_dtype=torch.double, ), dict( fullname='Unfold_int_input', constructor=lambda: nn.Unfold(2, 1, 0, 1), cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)', input_size=(2, 4, 3, 3), check_gradgrad=False, test_cuda=True, default_dtype=torch.double, ), dict( fullname='Fold_int_input', constructor=lambda: nn.Fold(3, 2, 1, 0, 1), cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', input_size=(2, 16, 4), check_gradgrad=False, test_cuda=True, default_dtype=torch.double, ), dict( fullname='Fold_no_batch_dim_int_input', constructor=lambda: nn.Fold(3, 2, 1, 0, 1), cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)', input_size=(16, 4), ref=single_batch_reference_fn, check_gradgrad=False, test_cuda=True, default_dtype=torch.double, ), dict( module_name='RReLU', constructor_args=(0.1, 0.9), cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)', input_size=(), desc='with_up_down_scalar', test_cuda=False, default_dtype=torch.double, ), dict( module_name='PairwiseDistance', input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), default_dtype=torch.double, ), dict( module_name='PairwiseDistance', input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)), desc='broadcast_lhs', default_dtype=torch.double, ), dict( module_name='PairwiseDistance', input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)), desc='broadcast_rhs', default_dtype=torch.double, ), dict( module_name='PairwiseDistance', constructor_args=(1.5, 1e-05, True), cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)', input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)), desc='with_non_default_args', default_dtype=torch.double, ), dict( module_name='PairwiseDistance', input_fn=lambda: (torch.randn(8), torch.randn(8)), reference_fn=single_batch_reference_fn, desc='no_batch_dim', default_dtype=torch.double, ), dict( module_name='TransformerEncoderLayer', constructor_args=(4, 2, 16, 0.0), cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) .dim_feedforward(16) .dropout(0.0)''', input_size=(2, 3, 4), desc='relu_activation', with_tf32=True, tf32_precision=0.1, # TODO(#50743): figure out the error # RuntimeError: The size of tensor a (6) must match the size of tensor b (4) # at non-singleton dimension 2 check_batched_grad=False, check_gradgrad=False, default_dtype=torch.double, ), dict( module_name='TransformerEncoderLayer', constructor_args=(4, 2, 8, 0.0, F.gelu), cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2) .dim_feedforward(8) .dropout(0.0) .activation(torch::kGELU)''', input_size=(2, 3, 4), check_gradgrad=False, desc='gelu_activation', with_tf32=True, tf32_precision=0.08 if SM90OrLater else 0.05, default_dtype=torch.double, ), dict( module_name='TransformerDecoderLayer', constructor_args=(4, 2, 8, 0.0), cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) .dim_feedforward(8) .dropout(0.0)''', 
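        # (Descriptive note: the pair produced by input_fn below is (tgt, memory) for
        #  TransformerDecoderLayer.forward; with the default batch_first=False these are
        #  (tgt_len, batch, d_model) = (3, 3, 4) and (memory_len, batch, d_model) = (2, 3, 4).)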
input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), check_gradgrad=False, desc='relu_activation', with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='TransformerDecoderLayer', constructor_args=(4, 2, 8, 0.0, F.gelu), cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2) .dim_feedforward(8) .dropout(0.0) .activation(torch::kGELU)''', input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)), check_gradgrad=False, desc='gelu_activation', with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), dict( module_name='Transformer', constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu), cpp_constructor_args='''torch::nn::TransformerOptions() .d_model(4) .nhead(2) .num_encoder_layers(2) .num_decoder_layers(2) .dim_feedforward(8) .dropout(0.0) .activation(torch::kReLU)''', input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)), check_gradgrad=False, desc='multilayer_coder', with_tf32=True, tf32_precision=0.05 if SM90OrLater else 0.03, default_dtype=torch.double, ), dict( module_name='Linear', constructor_args=(3, 5), cpp_constructor_args='torch::nn::LinearOptions(3, 5)', input_fn=lambda: torch.rand(3), reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1], desc="no_batch_dim", with_tf32=True, tf32_precision=0.005, default_dtype=torch.double, ), dict( module_name='Flatten', cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)', constructor_args=(-3, -1), input_size=(3, 4, 5), reference_fn=single_batch_reference_fn, desc="no_batch_dim", default_dtype=torch.double, ), dict( module_name='Unflatten', cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})', constructor_args=(-2, torch.Size([2, 2])), input_size=(3, 4, 5), reference_fn=single_batch_reference_fn, desc="no_batch_dim", default_dtype=torch.double, ), dict( module_name='LayerNorm', constructor_args=([56, 56, 56], 1e-5, False), cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)', input_size=(4, 56, 56, 56), cudnn=True, check_eval=True, gradcheck_fast_mode=True, check_half=True, desc='3d_no_affine_large_feature', ), ] # add conv padding mode tests: for padding_mode, cpp_padding_mode in zip( ['reflect', 'circular', 'replicate', 'zeros'], ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']): # conv signature: # in_channels, out_channels, kernel_size, stride=1, # padding=0, dilation=1, groups=1, # bias=True, padding_mode='zeros' for d in (1, 2, 3): if d == 3 and padding_mode == 'reflect': # FIXME: remove after implementing reflection pad 3d # https://github.com/pytorch/pytorch/issues/27655 continue padding = tuple(range(1, d + 1)) cpp_padding = '{' + ', '.join(map(str, padding)) + '}' input_size = (2, 2) + (4,) * d output_size = (2, 3) + tuple(p + 1 for p in padding) # simplified from `(4 + 2 * p - 3) // 2 + 1` new_module_tests.append( dict( module_name=f'Conv{d}d', constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode), cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3) .stride(2) .padding({cpp_padding}) .dilation(1) .groups(1) .bias(true) .padding_mode({cpp_padding_mode})''', input_size=input_size, output_size=output_size, cudnn=True, desc=f'{padding_mode}_stride2_pad2', with_tf32=True, tf32_precision=0.05, default_dtype=torch.double, ), ) # Check that non linear activations work with no batch dimensions non_linear_activations_no_batch = [ 'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU', 
'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU', 'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh', 'Tanhshrink', 'Threshold' ] non_linear_activations_extra_info: Dict[str, dict] = { 'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double}, 'Threshold': {'constructor_args': (2., 1.)}, 'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, 'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double}, # For RRelu, test that compare CPU and GPU results fail because RNG # is different between CPU and GPU 'RReLU': {'test_cuda': False, 'default_dtype': torch.double}, 'ELU': {'default_dtype': torch.double}, 'GELU': {'default_dtype': torch.double}, 'GLU': {'default_dtype': torch.double}, 'Hardshrink': {'default_dtype': torch.double}, 'Hardtanh': {'default_dtype': torch.double}, 'LeakyReLU': {'default_dtype': torch.double}, 'LogSigmoid': {'default_dtype': torch.double}, 'Mish': {'default_dtype': torch.double}, 'PReLU': {'default_dtype': torch.double}, 'ReLU6': {'default_dtype': torch.double}, 'ReLU': {'default_dtype': torch.double}, 'SELU': {'default_dtype': torch.double}, 'SiLU': {'default_dtype': torch.double}, 'Sigmoid': {'default_dtype': torch.double}, 'Softplus': {'default_dtype': torch.double}, 'Softshrink': {'default_dtype': torch.double}, 'Softsign': {'default_dtype': torch.double}, 'Tanh': {'default_dtype': torch.double}, 'Tanhshrink': {'default_dtype': torch.double}, } for non_linear_activation in non_linear_activations_no_batch: activation_test_info = dict( module_name=non_linear_activation, input_size=(4,), reference_fn=single_batch_reference_fn, desc='no_batch_dim', test_cpp_api_parity=False, ) extra_info = non_linear_activations_extra_info.get(non_linear_activation, {}) activation_test_info.update(extra_info) new_module_tests.append(activation_test_info) return new_module_tests def kldivloss_reference(input, target, reduction='mean', log_target=False): if log_target: result = torch.exp(target) * (target - input) else: result = target * (target.log() - input) if reduction == 'mean': return result.mean() elif reduction == 'sum': return result.sum() elif reduction == 'batchmean' and result.dim() != 0: return result.sum() / result.size(0) return result def nlllossNd_reference(input, target, weight=None, ignore_index=-100, reduction='mean'): assert input.dim() >= 3 N = input.size(0) C = input.size(1) out_size = (N,) + input.size()[2:] output = torch.zeros(out_size).type_as(input) if weight is None: weight = torch.ones(C).type_as(input) total_weight = 0 for tup in product(*[range(size) for size in out_size]): t_nx = target[tup] norm = 0. 
if ignore_index == t_nx else weight[t_nx].item() input_index = list(tup) input_index.insert(1, t_nx) output[tup] = -input[tuple(input_index)] * norm total_weight += norm if reduction == 'mean': return output.sum() / total_weight elif reduction == 'sum': return output.sum() return output def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean', label_smoothing=0.0): assert input.dim() >= 2 input = torch.log_softmax(input, 1) C = input.size(1) if weight is None: weight = torch.ones(C).type_as(input) weight = weight.view(1, C, *(1 for _ in input.shape[2:])) if label_smoothing > 0.0: assert label_smoothing <= 1.0 target = (target * (1 - label_smoothing) + label_smoothing / C) output = -(input * target * weight).sum(dim=1) if reduction == 'mean': return output.mean() elif reduction == 'sum': return output.sum() return output def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0): log_softmax_input = torch.log_softmax(input, 1) nllloss = F.nll_loss( log_softmax_input, target, weight, ignore_index=ignore_index, reduction=reduction) if label_smoothing == 0.0: return nllloss assert 0.0 < label_smoothing <= 1.0 input = torch.log_softmax(input, 1) C = input.size(1) if weight is not None: input = input * weight.view(1, C, *(1 for _ in input.shape[2:])) smooth_loss = -torch.sum(input, 1) ignore_mask = target == ignore_index smooth_loss.masked_fill_(ignore_mask, 0.0) if reduction == 'mean': if weight is not None: # TODO: This code can path can be removed if #61309 is resolved # loss is normalized by the weights to be consistent with nll_loss_nd ret = torch.sum(smooth_loss) / weight.gather(0, target.masked_select(ignore_mask.logical_not()).flatten()).sum() else: ret = torch.mean(smooth_loss.masked_select(ignore_mask.logical_not())) elif reduction == 'sum': ret = torch.sum(smooth_loss) else: ret = smooth_loss return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C) def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0): if input.shape == target.shape: return cross_entropy_loss_prob_target_reference( input, target, weight=weight, reduction=reduction, label_smoothing=label_smoothing) else: return cross_entropy_loss_indices_target_reference( input, target, weight=weight, reduction=reduction, ignore_index=ignore_index, label_smoothing=label_smoothing ) def nllloss_reference(input, target, weight=None, ignore_index=-100, reduction='mean'): def nll_loss_helper(input, target, weight, ignore_index): if target == ignore_index: return (0, 0) norm = 1 if weight is None else weight[target] result = -input[target] * norm return (result, norm) losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index) for i, t in zip(input, target)] losses, weights = zip(*losses_and_weights) losses_tensor = input.new_tensor(losses) if reduction == 'mean': return sum(losses_tensor) / sum(weights) elif reduction == 'sum': return sum(losses_tensor) else: return losses_tensor def smoothl1loss_reference(input, target, reduction='mean', beta=1.0): abs_diff = (input - target).abs() ge_beta_mask = (abs_diff >= beta).type_as(abs_diff) lt_beta_mask = (abs_diff < beta).type_as(abs_diff) # when beta <= 0 we should just use l1_loss if beta == 0: output = abs_diff else: output = ge_beta_mask * (abs_diff - 0.5 * beta) + lt_beta_mask * 0.5 * (abs_diff ** 2) / beta if reduction == 'mean': return output.mean() elif reduction == 'sum': return output.sum() return 
output def huberloss_reference(input, target, reduction='mean', delta=1.0): abs_diff = (input - target).abs() ge_delta_mask = (abs_diff >= delta) lt_delta_mask = (abs_diff < delta) output = ge_delta_mask * delta * (abs_diff - 0.5 * delta) + lt_delta_mask * 0.5 * (abs_diff ** 2) if reduction == 'mean': return output.mean() elif reduction == 'sum': return output.sum() return output def _multilabelmarginloss_reference(input, target): targets = [] for target_index in target: if target_index < 0: break targets.append(target_index) sum = 0 for target_index in targets: for i in range(0, len(input)): if i not in targets: sum += max(0, 1 - input[target_index] + input[i]) return sum def multilabelmarginloss_reference(input, target, reduction='mean'): # make everything 2-dimensional input_dim = input.dim() if input.dim() < 2: assert target.dim() < 2 input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0) target = target.unsqueeze(0) if target.dim() == 1 else target.unsqueeze(0).unsqueeze(0) n = input.size(0) dim = input.size(1) output = input.new(n).zero_() for i in range(0, n): output[i] = _multilabelmarginloss_reference(input[i], target[i]) if reduction == 'mean': return output.mean() / dim elif reduction == 'sum': return output.sum() / dim elif input_dim < 2: # we know we have (1, C) X (1, C) -> (1,), so squeeze will get us # back to correct dimensionality return output.squeeze() / dim else: return output / dim def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'): margin_clamp = (margin - input).clamp(min=0).type_as(input) output = torch.where(target == 1, input, margin_clamp) if reduction == 'mean': return output.mean() elif reduction == 'sum': return output.sum() return output def softmarginloss_reference(input, target, reduction='mean'): output = (1 + (-input * target).exp()).log() if reduction == 'mean': return output.mean() elif reduction == 'sum': return output.sum() return output def _multimarginloss_reference(input, target_idx, p, margin, weight): if weight is None: weight = input.new(len(input)).fill_(1) output = 0 for i in range(0, len(input)): if i != target_idx: output += weight[target_idx] * (max(0, (margin - input[target_idx] + input[i])) ** p) return output def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'): if input.dim() < 2: input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0) target_dim = target.dim() if target.dim() == 0: target = target.unsqueeze(0) n = input.size(0) dim = input.size(1) output = input.new(n) for x in range(0, n): output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight) if reduction == 'mean': return output.mean() / dim elif reduction == 'sum': return output.sum() / dim elif target_dim == 0: return output.squeeze(0) / dim return output / dim def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'): def _cos(a, b): cos = a.new(a.size(0)) for i in range(0, a.size(0)): cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5) return cos output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0)) if reduction == 'mean': return output.mean() elif reduction == 'sum': return output.sum() return output def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, reduction='mean'): d_p = torch.pairwise_distance(anchor, positive, p, eps) d_n = 
torch.pairwise_distance(anchor, negative, p, eps) if swap: d_s = torch.pairwise_distance(positive, negative, p, eps) d_n = torch.min(d_n, d_s) output = torch.clamp(margin + d_p - d_n, min=0.0) if reduction == 'mean': return output.mean() elif reduction == 'sum': return output.sum() return output def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'): output = (-target * (input1 - input2) + margin).clamp(min=0) if reduction == 'mean': return output.mean() elif reduction == 'sum': return output.sum() return output # this directly follows Graves et al.'s paper, in contrast to the production implementation, it does not use log-space def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'): input_lengths = torch.as_tensor(input_lengths, dtype=torch.long) target_lengths = torch.as_tensor(target_lengths, dtype=torch.long) dt = log_probs.dtype log_probs = log_probs.double() # we need the accuracy as we are not in logspace targets = targets.long() cum_target_lengths = target_lengths.cumsum(0) losses = [] for i in range(log_probs.size(1)): input_length = input_lengths[i].item() target_length = target_lengths[i].item() cum_target_length = cum_target_lengths[i].item() targets_prime = targets.new_full((2 * target_length + 1,), blank) if targets.dim() == 2: targets_prime[1::2] = targets[i, :target_length] else: targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length] probs = log_probs[:input_length, i].exp() alpha = log_probs.new_zeros((target_length * 2 + 1,)) alpha[0] = probs[0, blank] alpha[1] = probs[0, targets_prime[1]] mask_third = (targets_prime[:-2] != targets_prime[2:]) for t in range(1, input_length): alpha_next = alpha.clone() alpha_next[1:] += alpha[:-1] alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1)) alpha = probs[t, targets_prime] * alpha_next losses.append(-alpha[-2:].sum().log()[None]) output = torch.cat(losses, 0) if reduction == 'mean': output = (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean() elif reduction == 'sum': output = output.sum() output = output.to(dt) return output loss_reference_fns: Dict['str', Callable] = { 'KLDivLoss': kldivloss_reference, 'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True), 'NLLLoss': nllloss_reference, 'NLLLossNd': nlllossNd_reference, 'SmoothL1Loss': smoothl1loss_reference, 'HuberLoss': huberloss_reference, 'MultiLabelMarginLoss': multilabelmarginloss_reference, 'HingeEmbeddingLoss': hingeembeddingloss_reference, 'SoftMarginLoss': softmarginloss_reference, 'MultiMarginLoss': multimarginloss_reference, 'CosineEmbeddingLoss': cosineembeddingloss_reference, 'TripletMarginLoss': tripletmarginloss_reference, 'MarginRankingLoss': marginrankingloss_reference, 'CTCLoss': ctcloss_reference, 'CrossEntropyLoss': cross_entropy_loss_reference } criterion_tests = [] def single_batch_reference_criterion_fn(*args): """Reference function for criterion supporting no batch dimensions. The criterion is passed the input and target in batched form with a single item. The output is squeezed to compare with the no-batch input. 
""" criterion = args[-1] def unsqueeze_inp(inp): if isinstance(inp, (list, tuple)): return [t.unsqueeze(0) for t in inp] return inp.unsqueeze(0) def flatten(xs): result = [] if isinstance(xs, (list, tuple)): for x in xs: result.extend(flatten(x)) else: result.append(xs) return result single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]]) output = criterion(*single_batch_input_args) reduction = get_reduction(criterion) if reduction == 'none': return output.squeeze(0) # reduction is 'sum' or 'mean' which results in a scalar return output # Check that regression criterion work with no batch dimensions regression_criterion_no_batch = [ 'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss' ] reductions = ['none', 'mean', 'sum'] for name, reduction in product(regression_criterion_no_batch, reductions): regression_test_info = dict( fullname=f"{name}_no_batch_dim_{reduction}", constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction), input_size=(3, ), target_size=(3, ), reference_fn=single_batch_reference_criterion_fn, test_cpp_api_parity=False, default_dtype=torch.double, ) criterion_tests.append(regression_test_info) for reduction in reductions: regression_test_info = dict( fullname=f"KLDivLoss_no_batch_dim_{reduction}", constructor=lambda: nn.KLDivLoss(reduction=reduction), input_fn=lambda: torch.rand((3,)).log(), target_fn=lambda: torch.rand((3,)), reference_fn=single_batch_reference_criterion_fn, test_cpp_api_parity=False, default_dtype=torch.double, ) criterion_tests.append(regression_test_info) # Check that classification criterion work with no batch dimensions # List of tuples of (name, input_fn, target_fn) classification_criterion_no_batch = [ ( 'BCELoss', lambda: torch.sigmoid(torch.randn(9, dtype=torch.double)), lambda: torch.randn(9, dtype=torch.double).gt(0).to(torch.double) ), ('BCEWithLogitsLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9, dtype=torch.double)), ('HingeEmbeddingLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)), ('MultiLabelMarginLoss', lambda: torch.randn(4, dtype=torch.double), lambda: torch.tensor([3, 0, -1, 1])), ('SoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)), ('NLLLoss', lambda: F.log_softmax(torch.randn(3, dtype=torch.double), dim=0), lambda: torch.tensor(1)), ( 'CosineEmbeddingLoss', lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)), lambda: torch.tensor(1, dtype=torch.double) ), # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()), # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative ( 'TripletMarginLoss', lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)), lambda: torch.randn(9, dtype=torch.double) ), ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)), ] classification_criterion_no_batch_extra_info: Dict[str, dict] = { 'MultiLabelMarginLoss': {'check_gradgrad': False}, } # TODO : Fix these discrepancies classification_cpp_parity = { 'BCELoss': False, 'BCEWithLogitsLoss': False, 'HingeEmbeddingLoss': False, 'NLLLoss': False, 'SoftMarginLoss': False, } reductions = ['none', 'mean', 'sum'] for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch, reductions): classification_test_info = dict( 
fullname=f"{name}_no_batch_dim_{reduction}", constructor=lambda *args, name=name: getattr(nn, name)(reduction=reduction), input_fn=lambda f=input_fn: f(), target_fn=lambda f=target_fn: f(), reference_fn=single_batch_reference_criterion_fn, test_cpp_api_parity=True, has_parity=classification_cpp_parity.get(name, True) ) extra_info = classification_criterion_no_batch_extra_info.get(name, {}) classification_test_info.update(extra_info) criterion_tests.append(classification_test_info) class NNTestCase(TestCase): # _forward is defined in classes inheriting from NNTestCase @abstractmethod def _forward(self, *args, **kwargs): raise NotImplementedError @abstractmethod def _get_parameters(self, module: nn.Module) -> Tuple[List[nn.Parameter], List[nn.Parameter]]: raise NotImplementedError @abstractmethod def _zero_grad_parameters(self, module: nn.Module) -> None: raise NotImplementedError @abstractmethod def _backward(self, module: nn.Module, input: _TensorOrTensors, output: torch.Tensor, grad_output: Union[torch.Tensor, Sequence[torch.Tensor]], create_graph: bool = False): raise NotImplementedError def _jacobian(self, input, num_out): if isinstance(input, tuple): return tuple(self._jacobian(elem, num_out) for elem in input) elif isinstance(input, list): return [self._jacobian(elem, num_out) for elem in input] else: return torch.zeros(input.nelement(), num_out) def _flatten_tensors(self, x): if isinstance(x, torch.Tensor): if x.is_sparse: return x.to_dense().view(-1) else: return x.view(-1) else: return tuple(self._flatten_tensors(a) for a in x) def _zero_grad_input(self, input): if isinstance(input, torch.Tensor): if input.requires_grad and input.grad is not None: input.grad.zero_() input.grad.detach_() else: for i in input: self._zero_grad_input(i) def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True): output = self._forward(module, input) output_size = output.nelement() if jacobian_input: jacobian_inp = self._jacobian(input, output_size) flat_jacobian_input = list(_iter_tensors(jacobian_inp)) if jacobian_parameters: num_param = sum(p.numel() for p in self._get_parameters(module)[0]) jacobian_param = torch.zeros(num_param, output_size) for i in range(output_size): param, d_param = self._get_parameters(module) # make non grad zeros d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)] d_out = torch.zeros_like(output) flat_d_out = d_out.view(-1) flat_d_out[i] = 1 if jacobian_parameters: self._zero_grad_parameters(module) # Tensors will accumulate gradient from multiple steps if jacobian_input: self._zero_grad_input(input) d_input = self._backward(module, input, output, d_out) if jacobian_input: for jacobian_x, d_x in zip(flat_jacobian_input, _iter_tensors(d_input)): jacobian_x[:, i] = d_x.contiguous().view(-1) if jacobian_parameters: jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0) res: Tuple[torch.Tensor, ...] = () if jacobian_input: res += jacobian_inp, if jacobian_parameters: res += jacobian_param, return res def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True): def fw(*input): return self._forward(module, input).detach() res: Tuple[torch.Tensor, ...] 
= () if jacobian_input: res += _get_numerical_jacobian(fw, input, eps=1e-6), if jacobian_parameters: param, _ = self._get_parameters(module) to_cat = [] for p in param: jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6) # get_numerical_jacobian returns a list of tuples but we require a tensor to_cat.append(jacobian[0][0]) res += (torch.cat(to_cat, 0),) return res def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True): jacobian_parameters = bool(self._get_parameters(module)[0]) analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters) numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters) analytical_t = list(_iter_tensors(analytical)) numerical_t = list(_iter_tensors(numerical)) differences = [] for a, n in zip(analytical_t, numerical_t): if a.numel() != 0: differences.append(a.add(n, alpha=-1).abs().max()) # TODO: compare structure (ensure analytic jacobian has correct shape) if len(differences) > 0: self.assertLessEqual(max(differences), PRECISION) # type: ignore[type-var] class TestBase: _required_arg_names = {'constructor_args', 'input', 'extra_args'} def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs): self.desc = desc self.fullname = fullname self.constructor = constructor self.reference_fn = reference_fn for name in self._required_arg_names: if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs: if name in {'constructor_args', 'extra_args'}: kwargs[name] = () else: raise ValueError(f"{self.get_name()}: Specify {name} by a value, a function to generate it, or it's size!") self._extra_kwargs = kwargs self._arg_cache = {} def get_name(self): if self.fullname is not None: return 'test_' + self.fullname test_name = 'test_' + self.constructor.__name__ if self.desc: test_name += '_' + self.desc return test_name def _unpack(self, value): if isinstance(value, torch.Tensor): return value elif is_iterable(value): return type(value)(self._unpack(v) for v in value) else: return value @property def constructor_args(self): return self._get_arg('constructor_args', True) @property def extra_args(self): return self._get_arg('extra_args', True) def _get_arg(self, name, unpack): assert name in self._required_arg_names if name not in self._arg_cache: fn_name = name + '_fn' size_name = name + '_size' if name in self._extra_kwargs: self._arg_cache[name] = self._extra_kwargs[name] elif fn_name in self._extra_kwargs: self._arg_cache[name] = self._extra_kwargs[fn_name]() else: assert size_name in self._extra_kwargs, \ f"Missing `{name}`, `{size_name}` or `{fn_name}` for {self.get_name()}" def map_tensor_sizes(sizes): if isinstance(sizes, list): return [map_tensor_sizes(s) for s in sizes] elif isinstance(sizes, torch.Tensor): return sizes.double() else: return torch.randn(sizes) self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name]) return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name] def _get_input(self, unpack=True): return self._get_arg('input', unpack) def __call__(self, test_case): raise NotImplementedError class ModuleTest(TestBase): @abstractmethod def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any: raise NotImplementedError def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.jacobian_input = kwargs.get('jacobian_input', True) self.should_test_cuda = kwargs.get('test_cuda', True) self.should_test_pickle = kwargs.get('pickle', True) 
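        # (Descriptive note: the remaining flags come straight from the test dict --
        #  `check_gradgrad` gates the double-backward checks, `precision` is the atol used when
        #  comparing CPU vs. CUDA results in test_cuda, and `check_forward_only` makes __call__
        #  return right after the forward/reference comparison.)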
self.check_gradgrad = kwargs.get('check_gradgrad', True) self.FIXME_no_cuda_gradgrad_comparison = \ kwargs.get('FIXME_no_cuda_gradgrad_comparison', False) self.precision = kwargs.get('precision', 2e-4) self.check_forward_only = kwargs.get('check_forward_only', False) self.default_dtype = kwargs.get('default_dtype', None) if self.default_dtype is None: self.default_dtype = torch.get_default_dtype() def __call__(self, test_case): with set_default_dtype(self.default_dtype): module = self.constructor(*self.constructor_args) input = self._get_input() if self.reference_fn is not None: out = test_case._forward(module, input) ref_input = deepcopy(input) ref_module = deepcopy(module) expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module) test_case.assertEqual(out, expected_out, exact_dtype=False) if self.check_forward_only: return self.test_noncontig(test_case, module, input) if self.should_test_pickle: # TODO: do this with in-memory files as soon as torch.save will support it with tempfile.TemporaryFile() as f: test_case._forward(module, input) torch.save(module, f) f.seek(0) # weights_only=False as this is legacy code that saves the model module_copy = torch.load(f, weights_only=False) test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input)) self._do_test(test_case, module, input) def noncontiguize(self, obj): if isinstance(obj, list): return [self.noncontiguize(o) for o in obj] elif isinstance(obj, tuple): return tuple(self.noncontiguize(o) for o in obj) tensor = obj ndim = tensor.dim() # Always making only the last dimension noncontiguous is easy to hide # bugs because .view(-1) will still work. So try to find a dim with size # > 1 and make that non-contiguous, i.e., stack + select on the # dimension directly after that. dim = ndim for d in range(ndim): if tensor.size(d) > 1: dim = d + 1 break noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach() assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous() noncontig.requires_grad = tensor.requires_grad return noncontig def test_noncontig(self, test_case, module, input): # check no scalars, can't make non-contig if isinstance(input, torch.Tensor) and input.dim() == 0: return if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)): return test_case._zero_grad_parameters(module) test_case._zero_grad_input(input) with freeze_rng_state(): output = test_case._forward(module, input) if getattr(module, "return_indices", False): output = output[0] grad_output = output.new(output.shape).normal_() output = output.clone() d_input = deepcopy(test_case._backward(module, input, output, grad_output)) d_param = deepcopy(test_case._get_parameters(module)[1]) nc_input = self.noncontiguize(input) nc_grad_output = self.noncontiguize(grad_output) for contig_i, contig_g in product((True, False), repeat=2): i = input if contig_i else nc_input # Some ops, e.g., nn.Flatten, return gradient that shares # storage with the grad_output. Hence we copy here. 
go = deepcopy(grad_output if contig_g else nc_grad_output) test_case._zero_grad_parameters(module) test_case._zero_grad_input(i) with freeze_rng_state(): out = test_case._forward(module, i) if getattr(module, "return_indices", False): out = out[0] grad = test_case._backward(module, i, out, go) test_case.assertEqual(out, output) test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0) test_case.assertEqual(test_case._get_parameters(module)[1], d_param) def test_cuda(self, test_case): if not TEST_CUDA or not self.should_test_cuda: raise unittest.SkipTest('Excluded from CUDA tests') with set_default_dtype(self.default_dtype): cpu_input = self._get_input() type_map = {torch.double: torch.float} cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,) is_any_input_complex = any(isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple) gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map) cpu_module = self.constructor(*self.constructor_args) gpu_module = self.constructor(*self.constructor_args).float().cuda() cpu_param = test_case._get_parameters(cpu_module) gpu_param = test_case._get_parameters(gpu_module) for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]): gpu_p.data.copy_(cpu_p) test_case._zero_grad_input(cpu_input_tuple) test_case._zero_grad_input(gpu_input_tuple) test_case._zero_grad_parameters(cpu_module) test_case._zero_grad_parameters(gpu_module) cpu_output = test_case._forward(cpu_module, cpu_input_tuple) gpu_output = test_case._forward(gpu_module, gpu_input_tuple) if getattr(cpu_module, "return_indices", False): cpu_output = cpu_output[0] gpu_output = gpu_output[0] test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False) # Run backwards on CPU and GPU and compare results for _ in range(5): cpu_gradOutput = cpu_output.clone().normal_() gpu_gradOutput = cpu_gradOutput.type_as(gpu_output) cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput) gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput) test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False) for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]): test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0) # Run double-backwards on CPU and GPU and compare results if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison: cpu_output = cpu_module(*cpu_input_tuple) gpu_output = gpu_module(*gpu_input_tuple) if getattr(cpu_module, "return_indices", False): cpu_output = cpu_output[0] gpu_output = gpu_output[0] cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True) gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach() gpu_gradOutput.requires_grad = True cpu_gradInputs = torch.autograd.grad( cpu_output, cpu_input_tuple + tuple(cpu_module.parameters()), cpu_gradOutput, create_graph=True) gpu_gradInputs = torch.autograd.grad( gpu_output, gpu_input_tuple + tuple(gpu_module.parameters()), gpu_gradOutput, create_graph=True) for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs): test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False) # We mix output into the second backwards computation so that # torch.autograd.grad doesn't complain that some inputs # are unreachable (which can happen if you differentiate # only on the gradient. 


class InputVariableMixin:
    def _get_input(self):
        input = TestBase._get_input(self, False)  # type: ignore[arg-type]

        def map_variables(i):
            if isinstance(i, torch.Tensor):
                if i.is_floating_point() or i.is_complex():
                    i.requires_grad = True
                return i
            else:
                return type(i)(map_variables(elem) for elem in i)

        return map_variables(input)
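

# Illustrative sketch, not part of the harness: InputVariableMixin only marks
# floating point and complex tensors as requiring grad, mirroring map_variables
# above; integer inputs (e.g. embedding indices) are left untouched so gradcheck
# never tries to differentiate through them. The _example_* name is ours.
def _example_requires_grad_marking():
    float_inp = torch.randn(3, 4)
    long_inp = torch.tensor([0, 2, 1])
    for t in (float_inp, long_inp):
        if t.is_floating_point() or t.is_complex():
            t.requires_grad = True
    assert float_inp.requires_grad
    assert not long_inp.requires_grad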


class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cudnn = kwargs.get('cudnn', False)
        self.check_inplace = kwargs.get('check_inplace', False)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.skip_double = kwargs.get('skip_double', False)
        self.skip_half = kwargs.get('skip_half', False)
        self.with_tf32 = kwargs.get('with_tf32', False)
        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
        self.test_cpu = kwargs.get('test_cpu', True)
        self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
        self.check_batched_grad = kwargs.get('check_batched_grad', True)
        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
        self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
        self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)

    def _check_gradients(self, test_case, module, input_tuple):
        params = tuple(x for x in module.parameters())
        num_inputs = len(input_tuple)

        def fn_to_gradcheck(*inputs_and_params, **kwargs):
            assert not kwargs
            return test_case._forward(module, inputs_and_params[:num_inputs])

        # gradcheck doesn't support operators that take in dense inputs but
        # return sparse gradients for their parameters. This only happens in the
        # case of nn.Embedding and nn.EmbeddingBag. Instead, we call
        # `test_case.check_jacobian`, which is a slightly different version of
        # gradcheck that can handle this.
        if self.has_sparse_gradients:
            assert num_inputs == 1
            test_input_jacobian = torch.is_floating_point(input_tuple[0])
            test_case.check_jacobian(module, input_tuple[0], test_input_jacobian)
        else:
            test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params,
                                           check_batched_grad=self.check_batched_grad,
                                           fast_mode=self.gradcheck_fast_mode,
                                           check_forward_ad=self.supports_forward_ad))

        if self.check_gradgrad:
            test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params,
                                               check_batched_grad=self.check_batched_grad,
                                               fast_mode=self.gradcheck_fast_mode,
                                               check_fwd_over_rev=self.supports_fwgrad_bwgrad))

    def _do_test(self, test_case, module, input):
        num_threads = torch.get_num_threads()
        torch.set_num_threads(1)
        input_tuple = input if isinstance(input, tuple) else (input,)

        self._check_gradients(test_case, module, input_tuple)

        # check that the module can be printed
        module.__repr__()

        if self.check_inplace:
            # check that the in-place variant of the module gives the same
            # result as the out-of-place variant
            # check_inplace doesn't support multiple input tensors, since we don't have any modules
            # that modify the inputs in-place and that accept more than one input
            assert len(input_tuple) == 1
            input = input_tuple[0]

            module_ip = self.constructor(*self.constructor_args, inplace=True)

            input_version = input._version
            with freeze_rng_state():
                output = module(input)
            test_case.assertEqual(input._version, input_version)

            input_ip = deepcopy(input)
            input_ip_clone = input_ip.clone()
            with freeze_rng_state():
                output_ip = module_ip(input_ip_clone)
            test_case.assertNotEqual(input_ip_clone._version, input_version)
            test_case.assertEqual(output, output_ip)
            grad = output.data.clone().normal_()
            if input.grad is not None:
                with torch.no_grad():
                    input.grad.zero_()
            if input_ip.grad is not None:
                with torch.no_grad():
                    input_ip.grad.zero_()
            output.backward(grad)
            output_ip.backward(grad)
            test_case.assertEqual(input.grad, input_ip.grad)

        def assert_module_parameters_are(tensor_type, device_id=None):
            for p in module.parameters():
                test_case.assertIsInstance(p, tensor_type)
                if device_id is not None:
                    test_case.assertEqual(p.get_device(), device_id)

        if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA:
            # check that cuda() moves module parameters to the correct GPU device,
            # and that float() casts parameters correctly
            input_tuple = tuple(t.cuda() for t in input_tuple)
            module.float().cuda()
            module(*input_tuple)
            assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

            if torch.cuda.device_count() > 1:
                input_tuple = tuple(t.cuda(1) for t in input_tuple)
                module.cuda(1)
                with torch.cuda.device(1):
                    module(*input_tuple)
                assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
        else:
            # check that the float()/double() casts work correctly
            def to_type(tensor, real, complex):
                if tensor.is_complex():
                    return tensor.to(complex)
                elif tensor.is_floating_point():
                    return tensor.to(real)
                else:
                    return tensor

            def to_half(x):
                # TODO: torch.complex32 when properly supported
                return to_type(x, torch.float16, None)

            def to_single(x):
                return to_type(x, torch.float32, torch.complex64)

            def to_double(x):
                return to_type(x, torch.float64, torch.complex128)

            # to float
            input_tuple = tuple(to_single(t) for t in input_tuple)
            module.float()
            module(*input_tuple)
            assert_module_parameters_are(torch.FloatTensor)

            # and back to double
            input_tuple = tuple(to_double(t) for t in input_tuple)
            module.double()
            module(*input_tuple)
            assert_module_parameters_are(torch.DoubleTensor)

            if TEST_CUDA and self.should_test_cuda:
                # check that cuda() moves module parameters to the correct GPU device,
                # and that float() casts parameters correctly

                # to GPU0
                input_tuple = tuple(to_single(t).cuda() for t in input_tuple)
                module.float().cuda()
                module(*input_tuple)
                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

                # to CPU
                input_tuple = tuple(t.cpu() for t in input_tuple)
                module.cpu()
                module(*input_tuple)
                assert_module_parameters_are(torch.FloatTensor)

                # back to GPU0
                input_tuple = tuple(t.cuda() for t in input_tuple)
                module.cuda()
                module(*input_tuple)
                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

                # test that the forward of the module runs correctly without cuDNN
                if self.cudnn:
                    with torch.backends.cudnn.flags(enabled=False):
                        module(*input_tuple)
                        assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

                if torch.cuda.device_count() >= 2:
                    # test that cross-GPU transfer works
                    # to GPU1
                    input_tuple = tuple(t.cuda(1) for t in input_tuple)
                    module.cuda(1)
                    with torch.cuda.device(1):
                        module(*input_tuple)
                    assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]

                if not self.skip_double:
                    # test double()
                    input_tuple = tuple(to_double(t).cuda() for t in input_tuple)
                    module.double().cuda()
                    module(*input_tuple)
                    assert_module_parameters_are(torch.cuda.DoubleTensor, 0)  # type: ignore[attr-defined]

                # test half()
                if not self.skip_half:
                    input_tuple = tuple(to_half(t).cuda() for t in input_tuple)
                    module.half().cuda()
                    module(*input_tuple)
                    assert_module_parameters_are(torch.cuda.HalfTensor, 0)  # type: ignore[attr-defined]
        torch.set_num_threads(num_threads)

    def _get_target(self):
        return self._get_arg('target', False)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', False)
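

# Standalone sketch, not used by the harness: this mirrors what
# NewModuleTest._check_gradients does above, flattening a module into a plain
# function so that gradcheck/gradgradcheck perturb both the input and the
# parameters. The module choice and the _example_* name are ours.
def _example_module_gradcheck():
    module = nn.Linear(4, 3).double()
    inp = torch.randn(2, 4, dtype=torch.double, requires_grad=True)
    params = tuple(module.parameters())

    def fn(inp, *params):
        # The parameters are passed only so gradcheck perturbs them; the
        # module already holds the very same tensors.
        return module(inp)

    assert gradcheck(fn, (inp,) + params)
    assert gradgradcheck(fn, (inp,) + params)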


class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
    # TODO: check that criterions don't ignore grad_output

    _required_arg_names = TestBase._required_arg_names.union({'target'})

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.should_test_cuda = kwargs.get('test_cuda', True)
        self.check_forward_only = kwargs.get('check_forward_only', False)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.check_half = kwargs.get('check_half', True)
        self.check_bfloat16 = kwargs.get('check_bfloat16', False)
        self.check_complex = kwargs.get('check_complex', False)
        self.test_cpu = kwargs.get('test_cpu', True)
        self.with_tf32 = kwargs.get('with_tf32', True)
        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
        self.check_batched_grad = kwargs.get('check_batched_grad', True)
        self.default_dtype = kwargs.get('default_dtype', None)
        if self.default_dtype is None:
            self.default_dtype = torch.get_default_dtype()

    def __call__(self, test_case):
        with set_default_dtype(self.default_dtype):
            module = self.constructor(*self.constructor_args)
            input = self._get_input()

            # Check that these methods don't raise errors
            module.__repr__()
            str(module)

            target = self._get_target()

            if self.reference_fn is not None:
                out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
                ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
                expected_out = self.reference_fn(*ref_args)
                test_case.assertEqual(out, expected_out)

            if self.check_forward_only:
                return

            params = tuple(x for x in module.parameters())
            if not isinstance(input, tuple):
                inputs = (input,) + params + (target,)

                def apply_fn(input, target, *params):
                    return module(input, target)
            else:
                inputs = input + params + (target,)

                def apply_fn(input1, input2, target, *params):  # type: ignore[misc]
                    return module(input1, input2, target)

            gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)

            if self.check_gradgrad:
                gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)

    def test_cuda(self, test_case, dtype, extra_args=None):
        def convert_dtype(obj, dtype, requires_grad=False):
            if isinstance(obj, torch.Tensor):
                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
            elif isinstance(obj, tuple):
                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
            else:
                return obj

        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest('Excluded from CUDA tests')

        with set_default_dtype(self.default_dtype):
            cpu_input = self._get_input()
            cpu_target = self._get_target()
            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args)

            # Convert input, target and module parameters to dtype
            cpu_input = convert_dtype(cpu_input, dtype, True)
            if cpu_target.is_floating_point() or cpu_target.is_complex():
                cpu_target = convert_dtype(cpu_target, dtype)
            cpu_module.type(dtype)
            gpu_module.type(dtype)

            # GPU setup
            gpu_input = to_gpu(cpu_input)
            gpu_target = to_gpu(cpu_target)
            gpu_module.cuda()

            # torch.HalfTensor doesn't support most operations, so convert the
            # CPU side back to the default dtype
            if dtype in {torch.half, torch.bfloat16}:
                cpu_input = self._get_input()
                cpu_target = self._get_target()
                # Loss modules with weights require consistent input/module weight types
                cpu_module = self.constructor(*self.constructor_args)

            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
            # dtype could previously be None, so set the precision inline instead of using a precision map
            test_case.assertEqual(cpu_output, gpu_output,
                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4,
                                  rtol=0, exact_dtype=False)

            cpu_gradInput = test_case._backward_criterion(
                cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
            gpu_gradInput = test_case._backward_criterion(
                gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
            # dtype could previously be None, so set the precision inline instead of using a precision map
            test_case.assertEqual(cpu_gradInput, gpu_gradInput,
                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4,
                                  rtol=0, exact_dtype=False)

    def _get_target(self):
        return self._get_arg('target', False)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', False)

    @property
    def extra_args(self):
        return self._get_arg('extra_args', False)
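

# Standalone sketch, not used by the harness: the same shape as the gradcheck
# call in CriterionTest.__call__ above, with the target passed through as a
# constant (it does not require grad, so gradcheck only differentiates the
# input). The loss choice and the _example_* name are ours.
def _example_criterion_gradcheck():
    criterion = nn.MSELoss()
    inp = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
    target = torch.randn(5, 3, dtype=torch.double)  # constant w.r.t. gradcheck

    def apply_fn(inp, target):
        return criterion(inp, target)

    assert gradcheck(apply_fn, (inp, target))
    assert gradgradcheck(apply_fn, (inp, target))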


def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None):
    # fp32 compute
    input1 = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True)
    if scale_factor is not None:
        input1 = (torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor).float().requires_grad_()
    out1 = op(input1)
    grad_input1 = torch.randn_like(out1, device=device)
    out1.backward(grad_input1)

    # bfloat16 compute
    op_bfp16 = op.bfloat16()
    input2 = input1.detach().bfloat16().requires_grad_()
    grad_input2 = grad_input1.bfloat16()
    out2 = op_bfp16(input2)
    out2.backward(grad_input2)

    test_case.assertEqual(out1, out2, atol=prec, rtol=prec, exact_dtype=False)
    test_case.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=prec, exact_dtype=False)
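

# Hypothetical usage sketch for _test_bfloat16_ops (the op and tolerances are
# ours, for illustration only): inside a TestCase method one would compare an
# op's fp32 and bfloat16 forward/backward like this.
def _example_bfloat16_usage(test_case):
    _test_bfloat16_ops(test_case, nn.GELU(), "cpu", inp_dims=(4, 8), prec=1e-2)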


def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False):
    if not inference:
        inp.requires_grad_(True)
    out = module(inp)
    if not inference:
        gO = torch.rand_like(out)
        out.backward(gO)
    if check_size:
        test_case.assertEqual(out.size(), inp.size())
    if not inference:
        for p in module.parameters():
            if p.requires_grad:
                test_case.assertEqual(p.grad, torch.zeros_like(p.grad))
        test_case.assertEqual(inp.grad, torch.zeros_like(inp))


def _create_basic_net():
    class Layer(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.layer_dummy_param = nn.Parameter(torch.empty(3, 5))
            self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7))

    class Net(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.l1 = Layer()
            self.dummy_param = nn.Parameter(torch.empty(3, 5))
            self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1))

    l = Layer()
    n = Net()
    s = nn.Sequential(n, n)

    return l, n, s
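

# Usage sketch for _create_basic_net (the assertions are ours): the helper
# returns a bare Layer, a Net that nests another Layer, and a Sequential that
# reuses the same Net twice, which is handy for exercising parameter/buffer
# traversal and de-duplication.
def _example_basic_net_usage():
    l, n, s = _create_basic_net()
    # Net exposes its own parameter plus the nested Layer's parameter.
    assert len(list(n.parameters())) == 2
    assert len(list(n.buffers())) == 2
    # Sequential(n, n) reuses the same Net instance, so parameters de-duplicate.
    assert len(list(s.parameters())) == 2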