It is generally recommended to use `is`/`is not` to compare types. This series of changes applies that suggestion across the code base, with the aim of eventually enabling the related linter checks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037 Approved by: https://github.com/mlazos
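To illustrate the distinction (a minimal sketch, not code from the PR): `type(x) is T` matches the exact type only, while `isinstance(x, T)` also accepts subclasses, so `is` avoids accidentally matching subclasses in exact-type lookups such as the module mapping in this file.

```python
class Base: ...
class Sub(Base): ...

x = Sub()
print(isinstance(x, Base))  # True: isinstance accepts subclasses
print(type(x) is Base)      # False: `is` requires the exact type
print(type(x) is Sub)       # True: exact match
```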
from collections.abc import Callable
from typing import Optional

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.qat as nnqat
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import GraphModule
from torch.fx.graph import Node

from .ns_types import NSSingleResultType, NSSingleResultValuesType
from .utils import get_target_type_str, getattr_from_fqn, return_first_non_observer_node


toq = torch.ops.quantized


def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
    return mod.weight.detach()  # type: ignore[operator]


def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
    return mod[0].weight.detach()  # type: ignore[index]


def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
    return mod._weight_bias()[0]  # type: ignore[operator]


def get_lstm_weight(mod: nn.Module) -> list[torch.Tensor]:
    res = []
    for idx, param_name in enumerate(mod._flat_weights_names):  # type: ignore[arg-type]
        if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
            param_value = mod._flat_weights[idx].detach()  # type: ignore[index,union-attr]
            res.append(param_value)
    return res


def get_qlstm_weight(mod: nn.Module) -> list[torch.Tensor]:
    res = []
    for weight_value in mod._all_weight_values:  # type: ignore[union-attr]
        res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
        res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
    return res


def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return mod.weight.detach()
    elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
        return mod[0].weight.detach()  # type: ignore[operator]
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]


def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, nn.Linear):
        return mod.weight.detach()
    elif isinstance(mod, nni.LinearReLU):
        return mod[0].weight.detach()  # type: ignore[operator]
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]


def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
    # TODO(future PR): make more generic, handle everything
    if isinstance(mod, nn.LSTM):
        res = []
        for idx, param_name in enumerate(mod._flat_weights_names):
            if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
                param_value = mod._flat_weights[idx].detach()  # type: ignore[index,union-attr]
                res.append(param_value)
        return res
    else:
        assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
        res = []
        for weight_value in mod._all_weight_values:
            res.append(
                weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0]  # type: ignore[index]
            )
            res.append(
                weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0]  # type: ignore[index]
            )
        return res


def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    weight_arg_node = node.args[1]
    assert isinstance(weight_arg_node, Node)
    weight_node = return_first_non_observer_node(weight_arg_node, gm)
    assert isinstance(weight_node, Node)
    assert weight_node.op == "get_attr"
    weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
    return weight.detach()


def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # qconv state is arg 1
    qconv_state_node = node.args[1]
    assert isinstance(qconv_state_node, Node)
    assert qconv_state_node.op == "get_attr"
    qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target)  # type: ignore[arg-type]
    return qconv_state_obj.weight()


def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    # supported patterns:
    # weight -> obs -> linear
    # weight -> to(torch.float16) -> dequantize -> linear
    linear_second_arg = node.args[1]
    assert isinstance(linear_second_arg, Node)

    if linear_second_arg.op == "call_module":
        # weight -> obs -> linear
        weight_arg_node = node.args[1]
        assert isinstance(weight_arg_node, Node)
        weight_node = weight_arg_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == "get_attr"
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        return weight.detach()
    elif linear_second_arg.op == "call_method":
        # weight -> to(torch.float16) -> dequantize -> linear
        dequant_node = node.args[1]
        assert isinstance(dequant_node, Node)
        to_fp16_node = dequant_node.args[0]
        assert isinstance(to_fp16_node, Node)
        # extract the dtype, so we can cast to it before returning
        target_dtype = to_fp16_node.args[1]
        weight_node = to_fp16_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == "get_attr"
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        # return the weight with fp16 cast
        return weight.detach().to(target_dtype)
    else:
        assert linear_second_arg.op == "get_attr"
        weight = getattr_from_fqn(gm, linear_second_arg.target)  # type: ignore[arg-type]
        return weight.detach()


def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # packed weight is arg 1
    packed_weight_node = node.args[1]
    assert isinstance(packed_weight_node, Node)
    assert packed_weight_node.op == "get_attr"
    packed_weight = getattr_from_fqn(gm, packed_weight_node.target)  # type: ignore[arg-type]
    # TODO(future PR): why does packed_weight.unpack() not work?
    (weight, _bias), _name = packed_weight.__getstate__()
    return weight


def get_op_to_type_to_weight_extraction_fn() -> dict[str, dict[Callable, Callable]]:
    op_to_type_to_weight_extraction_fn: dict[str, dict[Callable, Callable]] = {
        "call_module": {
            # Conv1d
            nn.Conv1d: mod_weight_detach,
            nni.ConvReLU1d: mod_0_weight_detach,
            nnq.Conv1d: mod_weight_bias_0,
            nnqat.Conv1d: mod_weight_detach,
            nniqat.ConvBn1d: mod_weight_detach,
            nniqat.ConvBnReLU1d: mod_weight_detach,
            nniqat.ConvReLU1d: mod_weight_detach,
            nniq.ConvReLU1d: mod_weight_bias_0,
            # Conv2d
            nn.Conv2d: mod_weight_detach,
            nni.ConvReLU2d: mod_0_weight_detach,
            nnq.Conv2d: mod_weight_bias_0,
            nnqat.Conv2d: mod_weight_detach,
            nniqat.ConvBn2d: mod_weight_detach,
            nniqat.ConvBnReLU2d: mod_weight_detach,
            nniqat.ConvReLU2d: mod_weight_detach,
            nniq.ConvReLU2d: mod_weight_bias_0,
            # Conv3d
            nn.Conv3d: mod_weight_detach,
            nni.ConvReLU3d: mod_0_weight_detach,
            nnq.Conv3d: mod_weight_bias_0,
            nnqat.Conv3d: mod_weight_detach,
            nniqat.ConvBn3d: mod_weight_detach,
            nniqat.ConvBnReLU3d: mod_weight_detach,
            nniqat.ConvReLU3d: mod_weight_detach,
            nniq.ConvReLU3d: mod_weight_bias_0,
            # Linear
            nn.Linear: mod_weight_detach,
            nnq.Linear: mod_weight_bias_0,
            nni.LinearReLU: mod_0_weight_detach,
            nniq.LinearReLU: mod_weight_bias_0,
            nnqat.Linear: mod_weight_detach,
            nnqd.Linear: mod_weight_bias_0,
            nniqat.LinearReLU: mod_weight_detach,
            nniqat.LinearBn1d: mod_weight_detach,
            nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
            # LSTM
            nn.LSTM: get_lstm_weight,
            nnqd.LSTM: get_qlstm_weight,
        },
        "call_function": {
            # Conv
            F.conv1d: get_conv_fun_weight,
            F.conv2d: get_conv_fun_weight,
            F.conv3d: get_conv_fun_weight,
            toq.conv1d: get_qconv_fun_weight,
            toq.conv2d: get_qconv_fun_weight,
            toq.conv3d: get_qconv_fun_weight,
            toq.conv1d_relu: get_qconv_fun_weight,
            toq.conv2d_relu: get_qconv_fun_weight,
            toq.conv3d_relu: get_qconv_fun_weight,
            # Linear
            F.linear: get_linear_fun_weight,
            toq.linear: get_qlinear_fun_weight,
            toq.linear_relu: get_qlinear_fun_weight,
        },
    }

    return op_to_type_to_weight_extraction_fn


def extract_weight_from_node(
    node: Node,
    gm: GraphModule,
    op_to_type_to_weight_extraction_fn: Optional[
        dict[str, dict[Callable, Callable]]
    ] = None,
) -> Optional[NSSingleResultType]:
    res_type = NSSingleResultValuesType.WEIGHT.value

    # Not all graphmodules have _node_name_to_scope, so only fill it
    # out if it exists.
    fqn = None
    if hasattr(gm, "_node_name_to_scope"):
        fqn = gm._node_name_to_scope[node.name][0]  # type: ignore[index]

    if op_to_type_to_weight_extraction_fn is None:
        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()

    ref_node_type = get_target_type_str(node, gm)
    # for extracting weights, these are always the same
    prev_node_type = ref_node_type

    if node.op == "call_function":
        function_mapping = op_to_type_to_weight_extraction_fn["call_function"]
        for target_fn_type, weight_extraction_fn in function_mapping.items():
            if node.target == target_fn_type:
                weight = weight_extraction_fn(node, gm)
                return {
                    "type": res_type,
                    "values": [weight],
                    "prev_node_name": node.name,
                    "prev_node_target_type": prev_node_type,
                    "ref_node_name": node.name,
                    "ref_node_target_type": ref_node_type,
                    "index_within_arg": 0,
                    "index_of_arg": 0,
                    "fqn": fqn,
                }

    elif node.op == "call_module":
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
        for target_mod_type, weight_extraction_fn in module_mapping.items():
            if type(mod) is target_mod_type:
                weight = weight_extraction_fn(mod)
                return {
                    "type": res_type,
                    "values": [weight],
                    "prev_node_name": node.name,
                    "prev_node_target_type": prev_node_type,
                    "ref_node_name": node.name,
                    "ref_node_target_type": ref_node_type,
                    "index_within_arg": 0,
                    "index_of_arg": 0,
                    "fqn": fqn,
                }

    return None
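For context, a minimal usage sketch (an illustration, not part of the file above): it assumes this module is importable as `torch.ao.ns.fx.weight_utils`, as its relative imports suggest, and drives `extract_weight_from_node` over a symbolically traced toy model.

```python
import torch.nn as nn
from torch.fx import symbolic_trace

from torch.ao.ns.fx.weight_utils import extract_weight_from_node

# hypothetical toy model; tracing turns each layer into a call_module node
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 4))
gm = symbolic_trace(model)

# nodes whose exact type is in the mapping (Conv2d, Linear) yield a result
# dict; everything else (placeholder, ReLU, output) yields None
for node in gm.graph.nodes:
    result = extract_weight_from_node(node, gm)
    if result is not None:
        print(node.name, result["type"], result["values"][0].shape)
```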