import inspect

import torch
from unittest import mock
from unittest.mock import MagicMock, patch
from torch.testing import floating_types
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_utils import TestCase, run_tests


# Returns a database of args & kwargs that can be used to construct each module.
# Each entry is in class -> (args, kwargs) format.
# Example: torch.nn.Linear -> ([10, 5], {})
# TODO: Merge this in with the initial ModuleInfo implementation.
def build_constructor_arg_db():
    return {
        torch.nn.AdaptiveAvgPool1d: ((5,), {}),
        torch.nn.AdaptiveAvgPool2d: ((5,), {}),
        torch.nn.AdaptiveAvgPool3d: ((5,), {}),
        torch.nn.AdaptiveLogSoftmaxWithLoss: ((100, 20, [5, 10, 15]), {}),
        torch.nn.AdaptiveMaxPool1d: ((5,), {}),
        torch.nn.AdaptiveMaxPool2d: ((5,), {}),
        torch.nn.AdaptiveMaxPool3d: ((5,), {}),
        torch.nn.AlphaDropout: ((), {}),
        torch.nn.AvgPool1d: ((3,), {}),
        torch.nn.AvgPool2d: ((3,), {}),
        torch.nn.AvgPool3d: ((3,), {}),
        torch.nn.BCELoss: ((), {}),
        torch.nn.BCEWithLogitsLoss: ((), {}),
        torch.nn.BatchNorm1d: ((5,), {}),
        torch.nn.BatchNorm2d: ((5,), {}),
        torch.nn.BatchNorm3d: ((5,), {}),
        torch.nn.Bilinear: ((2, 3, 4), {}),
        torch.nn.CELU: ((), {}),
        torch.nn.CTCLoss: ((), {}),
        torch.nn.ChannelShuffle: ((4,), {}),
        torch.nn.ConstantPad1d: ((2, 3.5), {}),
        torch.nn.ConstantPad2d: ((2, 3.5), {}),
        torch.nn.ConstantPad3d: ((2, 3.5), {}),
        torch.nn.Conv1d: ((3, 3, 3), {}),
        torch.nn.Conv2d: ((3, 3, 3), {}),
        torch.nn.Conv3d: ((3, 3, 3), {}),
        torch.nn.ConvTranspose1d: ((3, 3, 3), {}),
        torch.nn.ConvTranspose2d: ((3, 3, 3), {}),
        torch.nn.ConvTranspose3d: ((3, 3, 3), {}),
        torch.nn.CosineEmbeddingLoss: ((), {}),
        torch.nn.CosineSimilarity: ((), {}),
        torch.nn.CrossEntropyLoss: ((), {}),
        torch.nn.CrossMapLRN2d: ((5,), {}),
        torch.nn.Dropout2d: ((), {}),
        torch.nn.Dropout3d: ((), {}),
        torch.nn.Dropout: ((), {}),
        torch.nn.ELU: ((), {}),
        torch.nn.Embedding: ((10, 5), {}),
        torch.nn.EmbeddingBag: ((10, 5), {}),
        torch.nn.FeatureAlphaDropout: ((), {}),
        torch.nn.Flatten: ((), {}),
        torch.nn.Fold: ((5, 2), {}),
        torch.nn.FractionalMaxPool2d: ((5, 2), {}),
        torch.nn.FractionalMaxPool3d: ((5, 2), {}),
        torch.nn.GELU: ((), {}),
        torch.nn.GLU: ((), {}),
        torch.nn.GRU: ((5, 10), {}),
        torch.nn.GRUCell: ((5, 10), {}),
        torch.nn.GaussianNLLLoss: ((), {}),
        torch.nn.GroupNorm: ((3, 6, 1e-5, True), {}),
        torch.nn.Hardshrink: ((), {}),
        torch.nn.Hardsigmoid: ((), {}),
        torch.nn.Hardswish: ((), {}),
        torch.nn.Hardtanh: ((), {}),
        torch.nn.HingeEmbeddingLoss: ((), {}),
        torch.nn.HuberLoss: ((), {}),
        torch.nn.Identity: ((), {}),
        torch.nn.InstanceNorm1d: ((5, 1e-5, 0.1, True), {}),
        torch.nn.InstanceNorm2d: ((5, 1e-5, 0.1, True), {}),
        torch.nn.InstanceNorm3d: ((5, 1e-5, 0.1, True), {}),
        torch.nn.KLDivLoss: ((), {}),
        torch.nn.L1Loss: ((), {}),
        torch.nn.LPPool1d: ((2, 3), {}),
        torch.nn.LPPool2d: ((2, 3), {}),
        torch.nn.LSTM: ((5, 10), {}),
        torch.nn.LSTMCell: ((5, 10), {}),
        torch.nn.LayerNorm: ((2,), {}),
        torch.nn.LazyBatchNorm1d: ((), {}),
        torch.nn.LazyBatchNorm2d: ((), {}),
        torch.nn.LazyBatchNorm3d: ((), {}),
        torch.nn.LazyConv1d: ((5, 2), {}),
        torch.nn.LazyConv2d: ((5, 2), {}),
        torch.nn.LazyConv3d: ((5, 2), {}),
        torch.nn.LazyConvTranspose1d: ((5, 2), {}),
        torch.nn.LazyConvTranspose2d: ((5, 2), {}),
        torch.nn.LazyConvTranspose3d: ((5, 2), {}),
        torch.nn.LazyLinear: ((5,), {}),
        torch.nn.LeakyReLU: ((), {}),
        torch.nn.Linear: ((10, 5), {}),
        torch.nn.LocalResponseNorm: ((2,), {}),
        torch.nn.LogSigmoid: ((), {}),
        torch.nn.LogSoftmax: ((), {}),
        torch.nn.MSELoss: ((), {}),
        torch.nn.MarginRankingLoss: ((), {}),
        torch.nn.MaxPool1d: ((3,), {}),
        torch.nn.MaxPool2d: ((3,), {}),
        torch.nn.MaxPool3d: ((3,), {}),
        torch.nn.MaxUnpool1d: ((5,), {}),
        torch.nn.MaxUnpool2d: ((5,), {}),
        torch.nn.MaxUnpool3d: ((5,), {}),
        torch.nn.ModuleDict: ((), {}),
        torch.nn.ModuleList: ((), {}),
        torch.nn.MultiLabelMarginLoss: ((), {}),
        torch.nn.MultiLabelSoftMarginLoss: ((), {}),
        torch.nn.MultiMarginLoss: ((), {}),
        torch.nn.MultiheadAttention: ((100, 2), {}),
        torch.nn.NLLLoss2d: ((), {}),
        torch.nn.NLLLoss: ((), {}),
        torch.nn.PReLU: ((), {}),
        torch.nn.PairwiseDistance: ((), {}),
        torch.nn.ParameterDict: ((), {}),
        torch.nn.ParameterList: ((), {}),
        torch.nn.PixelShuffle: ((2,), {}),
        torch.nn.PixelUnshuffle: ((2,), {}),
        torch.nn.PoissonNLLLoss: ((), {}),
        torch.nn.RNN: ((5, 10), {}),
        torch.nn.RNNBase: (('LSTM', 5, 10), {}),
        torch.nn.RNNCell: ((5, 10), {}),
        torch.nn.RNNCellBase: ((5, 10, True, 2), {}),
        torch.nn.RReLU: ((), {}),
        torch.nn.ReLU6: ((), {}),
        torch.nn.ReLU: ((), {}),
        torch.nn.ReflectionPad1d: ((2,), {}),
        torch.nn.ReflectionPad2d: ((2,), {}),
        torch.nn.ReplicationPad1d: ((2,), {}),
        torch.nn.ReplicationPad2d: ((2,), {}),
        torch.nn.ReplicationPad3d: ((2,), {}),
        torch.nn.SELU: ((), {}),
        torch.nn.Sequential: ((), {}),
        torch.nn.SiLU: ((), {}),
        torch.nn.Sigmoid: ((), {}),
        torch.nn.SmoothL1Loss: ((), {}),
        torch.nn.SoftMarginLoss: ((), {}),
        torch.nn.Softmax2d: ((), {}),
        torch.nn.Softmax: ((), {}),
        torch.nn.Softmin: ((), {}),
        torch.nn.Softplus: ((), {}),
        torch.nn.Softshrink: ((), {}),
        torch.nn.Softsign: ((), {}),
        torch.nn.SyncBatchNorm: ((5,), {}),
        torch.nn.Tanh: ((), {}),
        torch.nn.Tanhshrink: ((), {}),
        torch.nn.Threshold: ((0.1, 20), {}),
        torch.nn.Transformer: ((), {}),
        torch.nn.TransformerDecoder: ((torch.nn.TransformerDecoderLayer, 3), {}),
        torch.nn.TransformerDecoderLayer: ((10, 2), {}),
        torch.nn.TransformerEncoder: ((torch.nn.TransformerEncoderLayer, 3), {}),
        torch.nn.TransformerEncoderLayer: ((10, 2), {}),
        torch.nn.TripletMarginLoss: ((), {}),
        torch.nn.TripletMarginWithDistanceLoss: ((), {}),
        torch.nn.Unflatten: ((1, (2, 5, 5)), {}),
        torch.nn.Unfold: ((3,), {}),
        torch.nn.Upsample: ((), {}),
        torch.nn.UpsamplingBilinear2d: ((), {}),
        torch.nn.UpsamplingNearest2d: ((), {}),
        torch.nn.ZeroPad2d: ((0,), {}),
        torch.nn.qat.Conv2d: ((3, 3, 3), {
            'qconfig': torch.quantization.default_qconfig,
        }),
        torch.nn.qat.Conv3d: ((3, 3, 3), {
            'qconfig': torch.quantization.default_qconfig,
        }),
        torch.nn.qat.Linear: ((5, 2), {
            'qconfig': torch.quantization.default_qconfig,
        }),
        torch.nn.quantizable.LSTM: ((5, 6), {}),
        torch.nn.quantizable.LSTMCell: ((5, 6), {}),
        torch.nn.quantizable.MultiheadAttention: ((10, 2), {}),
        torch.nn.quantized.BatchNorm2d: ((2,), {}),
        torch.nn.quantized.BatchNorm3d: ((2,), {}),
        torch.nn.quantized.Conv1d: ((3, 3, 3), {}),
        torch.nn.quantized.Conv2d: ((3, 3, 3), {}),
        torch.nn.quantized.Conv3d: ((3, 3, 3), {}),
        torch.nn.quantized.ConvTranspose1d: ((3, 3, 3), {}),
        torch.nn.quantized.ConvTranspose2d: ((3, 3, 3), {}),
        torch.nn.quantized.ConvTranspose3d: ((16, 33, (3, 3, 5)), {
            'stride': (2, 1, 1),
            'padding': (4, 2, 2),
            'output_padding': (2, 2, 2),
            'dilation': (1, 1, 1),
        }),
        torch.nn.quantized.DeQuantize: ((), {}),
        torch.nn.quantized.ELU: ((0.01, 0), {}),
        torch.nn.quantized.Embedding: ((10, 3), {
            'factory_kwargs': {},
        }),
        torch.nn.quantized.EmbeddingBag: ((10, 3), {
            'factory_kwargs': {},
        }),
        torch.nn.quantized.GroupNorm: ((2, 3, torch.nn.Parameter(torch.tensor(2.)),
                                        torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
        torch.nn.quantized.Hardswish: ((0.1, 0,), {}),
        torch.nn.quantized.InstanceNorm1d: ((2, torch.nn.Parameter(torch.tensor(2.)),
                                             torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
        torch.nn.quantized.InstanceNorm2d: ((2, torch.nn.Parameter(torch.tensor(2.)),
                                             torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
        torch.nn.quantized.InstanceNorm3d: ((2, torch.nn.Parameter(torch.tensor(2.)),
                                             torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
        torch.nn.quantized.LayerNorm: ((2, torch.nn.Parameter(torch.tensor(2.)),
                                        torch.nn.Parameter(torch.tensor(2.)), 0.1, 0), {}),
        torch.nn.quantized.LeakyReLU: ((0.01, 0), {}),
        torch.nn.quantized.Linear: ((5, 2), {
            'factory_kwargs': {},
        }),
        torch.nn.quantized.MaxPool2d: ((3,), {}),
        torch.nn.quantized.Quantize: ((0.1, 0), {
            'dtype': torch.int16,
            'factory_kwargs': {},
        }),
        torch.nn.quantized.ReLU6: ((), {}),
        torch.nn.quantized.Sigmoid: ((0.1, 0), {}),
        torch.nn.quantized.FloatFunctional: ((), {}),
        torch.nn.quantized.FXFloatFunctional: ((), {}),
        torch.nn.quantized.QFunctional: ((), {}),
    }


# Instantiates the given class with the given args and kwargs, optionally on a given device.
def instantiate_class(cls, args, kwargs, extra_kwargs):
    return cls(*args, **kwargs) if extra_kwargs is None else cls(*args, **kwargs, **extra_kwargs)


# Returns a function that calls the real implementation of a method
# in addition to passing args to a mock object.
def mock_wrapper(method):
    mock = MagicMock()

    def wrapper(self, *args, **kwargs):
        mock(*args, **kwargs)
        return method(self, *args, **kwargs)
    wrapper.mock = mock
    return wrapper


# Returns a set of args / kwargs that can be used to construct the module.
def get_example_args(module_cls, constructor_arg_db, extra_kwargs=None):
    assert module_cls in constructor_arg_db, \
        f"No entry for {module_cls} in the constructor arg DB. Please add it to pass these tests."
    args, kwargs = constructor_arg_db[module_cls]
    extra_kwargs = {} if extra_kwargs is None else extra_kwargs

    # Recursively instantiate args / kwargs that are class objects.
    args = [instantiate_class(arg, *get_example_args(arg, constructor_arg_db), extra_kwargs=extra_kwargs)
            if inspect.isclass(arg)
            else torch.nn.Parameter(arg.to(**extra_kwargs)) if isinstance(arg, torch.nn.Parameter)
            else arg for arg in args]
    kwargs = {k: instantiate_class(v, *get_example_args(v, constructor_arg_db), extra_kwargs=extra_kwargs)
              if inspect.isclass(v)
              else torch.nn.Parameter(v.to(**extra_kwargs)) if isinstance(v, torch.nn.Parameter)
              else v for k, v in kwargs.items()}
    kwargs.update(extra_kwargs)
    return args, kwargs
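
# A minimal sketch of how the helpers above combine (illustrative only; the real
# flow lives in generate_test_func below). Patching Parameter.__new__ with a
# mock_wrapper preserves the real constructor while recording calls, so we can
# detect whether instantiating a module creates any parameters:
#
#   db = build_constructor_arg_db()
#   args, kwargs = get_example_args(torch.nn.Linear, db)  # -> [10, 5], {}
#   parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
#   with patch.object(torch.nn.Parameter, '__new__', parameter_new):
#       torch.nn.Linear(*args, **kwargs)
#   assert parameter_new.mock.called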

def generate_test_func(test_cls, module_cls, constructor_arg_db, verify_kwargs=True, module_is_lazy=False,
                       check_nonexistent_arg=True):
    # Generate a function for testing the given module.
    @dtypes(*floating_types())
    def run_test(test_cls, device, dtype, module_cls=module_cls):
        # Check if this module creates parameters or registers buffers.
        # The mock magic here passes through to the real Parameter / register_buffer
        # logic and is only used to check for calls.
        args, kwargs = get_example_args(module_cls, constructor_arg_db)

        # Some modules need to pass factory_kwargs so as not to conflict with existing args such as dtype.
        module_needs_factory_kwargs = 'factory_kwargs' in kwargs
        if module_needs_factory_kwargs:
            del kwargs['factory_kwargs']
            extra_kwargs = {
                'factory_kwargs': {
                    'device': device,
                    'dtype': dtype,
                }
            }
        else:
            extra_kwargs = {
                'device': device,
                'dtype': dtype,
            }

        parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
        with patch.object(torch.nn.Parameter, '__new__', parameter_new):
            register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
            with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
                m = module_cls(*args, **kwargs)
                module_creates_params_or_buffers = parameter_new.mock.called or register_buffer.mock.called

        # == Verify factory kwargs are supported. ==
        if verify_kwargs and module_creates_params_or_buffers:
            args, kwargs = get_example_args(module_cls, constructor_arg_db,
                                            extra_kwargs=extra_kwargs)
            if module_is_lazy:
                # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
                uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
                with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
                    uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
                    with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                        m = module_cls(*args, **kwargs)
                        uninit_param_new.mock.assert_has_calls(
                            [mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                        uninit_buffer_new.mock.assert_has_calls(
                            [mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
            else:
                # Check device placement and dtype for parameters and buffers.
                # Only verify floating point dtypes since that's what the kwarg applies to.
                # Note that dtype verification is also skipped if the module requires factory_kwargs.
                m = module_cls(*args, **kwargs)
                for name, param in m.named_parameters():
                    test_cls.assertEqual(
                        str(param.device), device,
                        f'Parameter {name} is on {param.device.type} instead of the expected device {device}')
                    if param.dtype.is_floating_point and not module_needs_factory_kwargs:
                        test_cls.assertEqual(
                            param.dtype, dtype,
                            f'Parameter {name} is of dtype {param.dtype} instead of the expected dtype {dtype}')
                for name, buffer in m.named_buffers():
                    test_cls.assertEqual(
                        str(buffer.device), device,
                        f'Buffer {name} is on {buffer.device.type} instead of the expected device {device}')
                    if buffer.dtype.is_floating_point and not module_needs_factory_kwargs:
                        test_cls.assertEqual(
                            buffer.dtype, dtype,
                            f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}')

        # == Verify passing a nonexistent arg errors out. ==
        if check_nonexistent_arg:
            with test_cls.assertRaises(TypeError):
                m = module_cls(*args, **kwargs, nonexistent_arg='foo')

    return run_test
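
# For exposition, the factory-kwargs check that each generated test performs
# reduces to roughly the following (illustrative; assumes a 'cpu' device,
# torch.float32, and a module that creates parameters):
#
#   m = torch.nn.Linear(10, 5, device='cpu', dtype=torch.float32)
#   assert all(str(p.device) == 'cpu' and p.dtype == torch.float32
#              for p in m.parameters())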

def generate_tests(test_cls, constructor_arg_db):
    # Test all modules underneath these namespaces...
    NAMESPACES = [
        torch.nn,
        torch.nn.qat,
        torch.nn.quantizable,
        torch.nn.quantized,
    ]
    # ...except these.
    MODULES_TO_SKIP = {
        torch.nn.Module,
        torch.nn.Container,  # deprecated
        torch.nn.NLLLoss2d,  # deprecated
        torch.nn.quantized._ConvNd,  # base class in __all__ for some reason
        # TODO: Remove these 2 from this list once the ASan issue is fixed.
        # See https://github.com/pytorch/pytorch/issues/55396
        torch.nn.quantized.Embedding,
        torch.nn.quantized.EmbeddingBag,
    }
    # No need to support kwargs for these modules even though they have
    # parameters / buffers, because those are passed in already instantiated.
    MODULES_WITHOUT_KWARGS_SUPPORT = {
        torch.nn.BCELoss,
        torch.nn.BCEWithLogitsLoss,
        torch.nn.CrossEntropyLoss,
        torch.nn.FractionalMaxPool2d,
        torch.nn.FractionalMaxPool3d,
        torch.nn.MultiLabelSoftMarginLoss,
        torch.nn.MultiMarginLoss,
        torch.nn.NLLLoss,
        torch.nn.TransformerDecoder,
        torch.nn.TransformerEncoder,
    }
    # Modules that already accepted arbitrary kwargs before this change;
    # the nonexistent arg check is skipped for these.
    MODULES_WITH_PREVIOUS_KWARGS = {
        torch.nn.Identity,
    }
    # Lazy modules don't instantiate parameters right away.
    LAZY_MODULES = {
        torch.nn.LazyBatchNorm1d,
        torch.nn.LazyBatchNorm2d,
        torch.nn.LazyBatchNorm3d,
        torch.nn.LazyConv1d,
        torch.nn.LazyConv2d,
        torch.nn.LazyConv3d,
        torch.nn.LazyConvTranspose1d,
        torch.nn.LazyConvTranspose2d,
        torch.nn.LazyConvTranspose3d,
        torch.nn.LazyLinear,
    }
    # These modules require the FBGEMM backend to instantiate.
    MODULES_THAT_REQUIRE_FBGEMM = {
        torch.nn.quantized.Conv1d,
        torch.nn.quantized.Conv2d,
        torch.nn.quantized.Conv3d,
        torch.nn.quantized.ConvTranspose1d,
        torch.nn.quantized.ConvTranspose2d,
        torch.nn.quantized.ConvTranspose3d,
        torch.nn.quantized.Linear,
    }

    for namespace in NAMESPACES:
        # The "nn" in "torch.nn".
        namespace_basename = namespace.__name__.split('.')[-1]
        for module_name in namespace.modules.__all__:
            # Class object for this module (e.g. torch.nn.Linear).
            module_cls = getattr(namespace.modules, module_name)
            if module_cls in MODULES_TO_SKIP:
                continue
            verify_kwargs = module_cls not in MODULES_WITHOUT_KWARGS_SUPPORT
            module_is_lazy = module_cls in LAZY_MODULES
            check_nonexistent_arg = module_cls not in MODULES_WITH_PREVIOUS_KWARGS
            # Generate a function for testing this module and setattr it onto the test class.
            run_test = generate_test_func(test_cls, module_cls, constructor_arg_db,
                                          verify_kwargs=verify_kwargs,
                                          module_is_lazy=module_is_lazy,
                                          check_nonexistent_arg=check_nonexistent_arg)
            test_name = f'test_{namespace_basename}_{module_name}'
            if module_cls in MODULES_THAT_REQUIRE_FBGEMM:
                run_test = skipIfNoFBGEMM(run_test)
            setattr(test_cls, test_name, run_test)


class TestModuleInit(TestCase):
    _ignore_not_implemented_error = False


generate_tests(TestModuleInit, build_constructor_arg_db())

instantiate_device_type_tests(TestModuleInit, globals())

if __name__ == '__main__':
    run_tests()
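
# Note: instantiate_device_type_tests replicates TestModuleInit per available
# device type, so the runnable tests live on generated classes with device and
# dtype suffixes, e.g. (illustrative, on a CPU-only build):
#   TestModuleInitCPU.test_nn_Linear_cpu_float32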