[fx2trt] Fix dummy weight initialization in conv1d converter (#78402)

Summary:
As titled; currently it errors out with the following error:
```
---> 72         dummy_weight = trt.Weights(weight_shape)
     73         layer = network.add_convolution_nd(
     74             input=input_val,
TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
    1. tensorrt.tensorrt.Weights(type: tensorrt.tensorrt.DataType = <DataType.FLOAT: 0>)
    2. tensorrt.tensorrt.Weights(a: numpy.ndarray)
```
full error: https://www.internalfb.com/phabricator/paste/view/P503598381
we need to pass around a numpy ndarray instead of a shape here.

This also adds support for conv1d in backend_config_dict for the tensorrt backend.

Test Plan:
```
buck test mode/opt deeplearning/trt/fx2trt_oss/test/converters:test_convolution
```

```
buck test mode/opt deeplearning/trt/fx2trt_oss/test/quant:test_quant_trt
```

Differential Revision: D36721313

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78402
Approved by: https://github.com/842974287
This commit is contained in:
Jerry Zhang
2022-05-27 04:48:45 +00:00
committed by PyTorch MergeBot
parent 299fbbccec
commit 85f308275e
2 changed files with 10 additions and 149 deletions

View File

@ -262,13 +262,12 @@ def _get_linear_configs(dtype_configs):
})
return linear_configs
def _get_conv_configs():
def _get_conv_configs(dtype_configs):
"""
Return all configs related to conv modules and ops.
"""
conv_configs = []
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
dtype_configs = [weighted_op_int8_dtype_config]
for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]:
# (1) Single conv modules/functions
@ -685,6 +684,7 @@ def _get_embedding_op_configs():
def get_native_backend_config_dict():
""" Get backend_config_dict for PyTorch Native backend (fbgemm/qnnpack). """
conv_dtype_configs = [weighted_op_int8_dtype_config]
linear_dtype_configs = [
weighted_op_int8_dtype_config,
default_dynamic_int8_dtype_config,
@ -706,7 +706,7 @@ def get_native_backend_config_dict():
"configs": [
*_DEFAULT_OP_INT8_CONFIGS,
*_get_linear_configs(linear_dtype_configs),
*_get_conv_configs(),
*_get_conv_configs(conv_dtype_configs),
*_get_binary_op_configs(binary_op_dtype_configs),
*_get_fixed_qparams_op_configs(),
_CAT_CONFIG,

View File

@ -1,15 +1,11 @@
import torch
from .observation_type import ObservationType
import torch.nn.qat as nnqat
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
# TODO: maybe refactor this to a separate util function
from .native import _get_binary_op_configs
from .native import _get_linear_configs
from .native import _get_conv_configs
from .native import _get_share_qparams_op_configs
from ..fuser_method_mappings import reverse_sequential_wrapper2
def get_tensorrt_backend_config_dict():
""" Get the backend config dictionary for tensorrt backend
NOTE: Current api will change in the future, it's just to unblock experimentation for
@ -34,132 +30,6 @@ def get_tensorrt_backend_config_dict():
"output_dtype": torch.qint8,
}
# operator (module/functional/torch ops) configs
linear_qat_config = {
"pattern": nnqat.Linear,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"root_module": torch.nn.Linear,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Linear,
}
# TODO: maybe make "pattern" to be a list of patterns
# TODO: current patterns are the ones after fusion, we will want to expose fusion
# here as well in the future, maybe we need to
linear_relu_mm_config = {
"pattern": (torch.nn.ReLU, torch.nn.Linear),
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"fuser_method": reverse_sequential_wrapper2(nni.LinearReLU),
"fused_module": nni.LinearReLU,
}
linear_relu_mf_config = {
"pattern": (torch.nn.functional.relu, torch.nn.Linear),
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"fuser_method": reverse_sequential_wrapper2(nni.LinearReLU),
"fused_module": nni.LinearReLU,
}
linear_relu_fused_config = {
"pattern": nni.LinearReLU,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"root_module": torch.nn.Linear,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Linear,
"qat_module": nniqat.LinearReLU,
}
linear_relu_qat_config = {
"pattern": nniqat.LinearReLU,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"root_module": torch.nn.Linear,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Linear,
}
conv_module_config = {
"pattern": torch.nn.Conv2d,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"root_module": torch.nn.Conv2d,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
"qat_module": nnqat.Conv2d,
}
conv_qat_config = {
"pattern": nnqat.Conv2d,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"root_module": torch.nn.Conv2d,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
}
conv1d_relu_fused_config = {
"pattern": nni.ConvReLU1d,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"root_module": torch.nn.Conv1d,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv1d,
}
conv2d_relu_fused_config = {
"pattern": nni.ConvReLU2d,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"root_module": torch.nn.Conv2d,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
"qat_module": nniqat.ConvReLU2d,
}
conv2d_relu_qat_config = {
"pattern": nniqat.ConvReLU2d,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"root_module": torch.nn.Conv2d,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
}
conv3d_relu_fused_config = {
"pattern": nni.ConvReLU3d,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"root_module": torch.nn.Conv3d,
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv3d,
"qat_module": nniqat.ConvReLU3d,
}
conv2d_relu_mf_config = {
"pattern": (torch.nn.functional.relu, torch.nn.Conv2d),
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"fuser_method": reverse_sequential_wrapper2(nni.ConvReLU2d),
"fused_module": nni.ConvReLU2d,
}
conv2d_relu_mm_config = {
"pattern": (torch.nn.ReLU, torch.nn.Conv2d),
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
"dtype_configs": [
weighted_op_qint8_dtype_config,
],
"fuser_method": reverse_sequential_wrapper2(nni.ConvReLU2d),
"fused_module": nni.ConvReLU2d,
}
addmm_config = {
"pattern": torch.addmm,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
@ -180,6 +50,9 @@ def get_tensorrt_backend_config_dict():
non_weighted_op_qint8_dtype_config,
]
}
conv_dtype_configs = [
weighted_op_qint8_dtype_config,
]
linear_dtype_configs = [
weighted_op_qint8_dtype_config,
]
@ -193,21 +66,9 @@ def get_tensorrt_backend_config_dict():
# optional
"name": "tensorrt",
"configs": [
linear_qat_config,
linear_relu_fused_config,
linear_relu_qat_config,
linear_relu_mm_config,
linear_relu_mf_config,
conv_module_config,
conv_qat_config,
# conv1d is not supported in fx2trt
# conv1d_relu_fused_config,
conv2d_relu_fused_config,
conv2d_relu_qat_config,
conv2d_relu_mf_config,
conv2d_relu_mm_config,
# conv3d is not supported in fx2trt
# conv3d_relu_fused_config,
# there might be things not supported in fx2trt, but it will error out
# during fx2trt conversion and can support them after that
*_get_conv_configs(conv_dtype_configs),
addmm_config,
cat_config,
*_get_linear_configs(linear_dtype_configs),