mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fx2trt] Fix dummy weight initialization in conv1d converter (#78402)
Summary: As the title says. Currently it errors out with the following error: ``` ---> 72 dummy_weight = trt.Weights(weight_shape) 73 layer = network.add_convolution_nd( 74 input=input_val, TypeError: __init__(): incompatible constructor arguments. The following argument types are supported: 1. tensorrt.tensorrt.Weights(type: tensorrt.tensorrt.DataType = <DataType.FLOAT: 0>) 2. tensorrt.tensorrt.Weights(a: numpy.ndarray) ``` Full error: https://www.internalfb.com/phabricator/paste/view/P503598381 We need to pass around a numpy ndarray instead of a shape here, and support conv1d in backend_config_dict for tensorrt. Test Plan: ``` buck test mode/opt deeplearning/trt/fx2trt_oss/test/converters:test_convolution ``` ``` buck test mode/opt deeplearning/trt/fx2trt_oss/test/quant:test_quant_trt ``` Differential Revision: D36721313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78402 Approved by: https://github.com/842974287
This commit is contained in:
committed by
PyTorch MergeBot
parent
299fbbccec
commit
85f308275e
@ -262,13 +262,12 @@ def _get_linear_configs(dtype_configs):
|
||||
})
|
||||
return linear_configs
|
||||
|
||||
def _get_conv_configs():
|
||||
def _get_conv_configs(dtype_configs):
|
||||
"""
|
||||
Return all configs related to conv modules and ops.
|
||||
"""
|
||||
conv_configs = []
|
||||
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
|
||||
dtype_configs = [weighted_op_int8_dtype_config]
|
||||
for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]:
|
||||
|
||||
# (1) Single conv modules/functions
|
||||
@ -685,6 +684,7 @@ def _get_embedding_op_configs():
|
||||
|
||||
def get_native_backend_config_dict():
|
||||
""" Get backend_config_dict for PyTorch Native backend (fbgemm/qnnpack). """
|
||||
conv_dtype_configs = [weighted_op_int8_dtype_config]
|
||||
linear_dtype_configs = [
|
||||
weighted_op_int8_dtype_config,
|
||||
default_dynamic_int8_dtype_config,
|
||||
@ -706,7 +706,7 @@ def get_native_backend_config_dict():
|
||||
"configs": [
|
||||
*_DEFAULT_OP_INT8_CONFIGS,
|
||||
*_get_linear_configs(linear_dtype_configs),
|
||||
*_get_conv_configs(),
|
||||
*_get_conv_configs(conv_dtype_configs),
|
||||
*_get_binary_op_configs(binary_op_dtype_configs),
|
||||
*_get_fixed_qparams_op_configs(),
|
||||
_CAT_CONFIG,
|
||||
|
@ -1,15 +1,11 @@
|
||||
import torch
|
||||
from .observation_type import ObservationType
|
||||
import torch.nn.qat as nnqat
|
||||
import torch.nn.intrinsic as nni
|
||||
import torch.nn.intrinsic.qat as nniqat
|
||||
# TODO: maybe refactor this to a separate util function
|
||||
from .native import _get_binary_op_configs
|
||||
from .native import _get_linear_configs
|
||||
from .native import _get_conv_configs
|
||||
from .native import _get_share_qparams_op_configs
|
||||
|
||||
from ..fuser_method_mappings import reverse_sequential_wrapper2
|
||||
|
||||
def get_tensorrt_backend_config_dict():
|
||||
""" Get the backend config dictionary for tensorrt backend
|
||||
NOTE: Current api will change in the future, it's just to unblock experimentation for
|
||||
@ -34,132 +30,6 @@ def get_tensorrt_backend_config_dict():
|
||||
"output_dtype": torch.qint8,
|
||||
}
|
||||
|
||||
# operator (module/functional/torch ops) configs
|
||||
linear_qat_config = {
|
||||
"pattern": nnqat.Linear,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"root_module": torch.nn.Linear,
|
||||
"reference_quantized_module_for_root": torch.nn.quantized._reference.Linear,
|
||||
}
|
||||
# TODO: maybe make "pattern" to be a list of patterns
|
||||
# TODO: current patterns are the ones after fusion, we will want to expose fusion
|
||||
# here as well in the future, maybe we need to
|
||||
linear_relu_mm_config = {
|
||||
"pattern": (torch.nn.ReLU, torch.nn.Linear),
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"fuser_method": reverse_sequential_wrapper2(nni.LinearReLU),
|
||||
"fused_module": nni.LinearReLU,
|
||||
}
|
||||
linear_relu_mf_config = {
|
||||
"pattern": (torch.nn.functional.relu, torch.nn.Linear),
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"fuser_method": reverse_sequential_wrapper2(nni.LinearReLU),
|
||||
"fused_module": nni.LinearReLU,
|
||||
}
|
||||
|
||||
linear_relu_fused_config = {
|
||||
"pattern": nni.LinearReLU,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"root_module": torch.nn.Linear,
|
||||
"reference_quantized_module_for_root": torch.nn.quantized._reference.Linear,
|
||||
"qat_module": nniqat.LinearReLU,
|
||||
}
|
||||
linear_relu_qat_config = {
|
||||
"pattern": nniqat.LinearReLU,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"root_module": torch.nn.Linear,
|
||||
"reference_quantized_module_for_root": torch.nn.quantized._reference.Linear,
|
||||
}
|
||||
conv_module_config = {
|
||||
"pattern": torch.nn.Conv2d,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"root_module": torch.nn.Conv2d,
|
||||
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
|
||||
"qat_module": nnqat.Conv2d,
|
||||
}
|
||||
conv_qat_config = {
|
||||
"pattern": nnqat.Conv2d,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"root_module": torch.nn.Conv2d,
|
||||
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
|
||||
}
|
||||
conv1d_relu_fused_config = {
|
||||
"pattern": nni.ConvReLU1d,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"root_module": torch.nn.Conv1d,
|
||||
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv1d,
|
||||
}
|
||||
conv2d_relu_fused_config = {
|
||||
"pattern": nni.ConvReLU2d,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"root_module": torch.nn.Conv2d,
|
||||
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
|
||||
"qat_module": nniqat.ConvReLU2d,
|
||||
}
|
||||
conv2d_relu_qat_config = {
|
||||
"pattern": nniqat.ConvReLU2d,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"root_module": torch.nn.Conv2d,
|
||||
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
|
||||
}
|
||||
conv3d_relu_fused_config = {
|
||||
"pattern": nni.ConvReLU3d,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"root_module": torch.nn.Conv3d,
|
||||
"reference_quantized_module_for_root": torch.nn.quantized._reference.Conv3d,
|
||||
"qat_module": nniqat.ConvReLU3d,
|
||||
}
|
||||
conv2d_relu_mf_config = {
|
||||
"pattern": (torch.nn.functional.relu, torch.nn.Conv2d),
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"fuser_method": reverse_sequential_wrapper2(nni.ConvReLU2d),
|
||||
"fused_module": nni.ConvReLU2d,
|
||||
}
|
||||
conv2d_relu_mm_config = {
|
||||
"pattern": (torch.nn.ReLU, torch.nn.Conv2d),
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
"dtype_configs": [
|
||||
weighted_op_qint8_dtype_config,
|
||||
],
|
||||
"fuser_method": reverse_sequential_wrapper2(nni.ConvReLU2d),
|
||||
"fused_module": nni.ConvReLU2d,
|
||||
}
|
||||
addmm_config = {
|
||||
"pattern": torch.addmm,
|
||||
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
|
||||
@ -180,6 +50,9 @@ def get_tensorrt_backend_config_dict():
|
||||
non_weighted_op_qint8_dtype_config,
|
||||
]
|
||||
}
|
||||
conv_dtype_configs = [
|
||||
weighted_op_qint8_dtype_config,
|
||||
]
|
||||
linear_dtype_configs = [
|
||||
weighted_op_qint8_dtype_config,
|
||||
]
|
||||
@ -193,21 +66,9 @@ def get_tensorrt_backend_config_dict():
|
||||
# optional
|
||||
"name": "tensorrt",
|
||||
"configs": [
|
||||
linear_qat_config,
|
||||
linear_relu_fused_config,
|
||||
linear_relu_qat_config,
|
||||
linear_relu_mm_config,
|
||||
linear_relu_mf_config,
|
||||
conv_module_config,
|
||||
conv_qat_config,
|
||||
# conv1d is not supported in fx2trt
|
||||
# conv1d_relu_fused_config,
|
||||
conv2d_relu_fused_config,
|
||||
conv2d_relu_qat_config,
|
||||
conv2d_relu_mf_config,
|
||||
conv2d_relu_mm_config,
|
||||
# conv3d is not supported in fx2trt
|
||||
# conv3d_relu_fused_config,
|
||||
# there might be things not supported in fx2trt, but it will error out
|
||||
# during fx2trt conversion and can support them after that
|
||||
*_get_conv_configs(conv_dtype_configs),
|
||||
addmm_config,
|
||||
cat_config,
|
||||
*_get_linear_configs(linear_dtype_configs),
|
||||
|
Reference in New Issue
Block a user