mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[reland][quant] Add utility function get_fqn_to_example_inputs
Summary: After https://github.com/pytorch/pytorch/pull/77608, `example_inputs` is a required input for `prepare_fx` and `prepare_qat_fx`. This makes quantizing submodules harder, so we added this utility function to get a dictionary from fqn to submodule example_inputs. Example call: ``` example_inputs = (tensor0,) get_fqn_to_example_inputs(m, example_inputs) ``` Example output: ``` { "linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,), "sub.linear1": (tensor4,), ... } ``` Test Plan: python test/test_quantization.py TestUtils Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/78286 Approved by: https://github.com/dzdang
This commit is contained in:
committed by
PyTorch MergeBot
parent
56c23f5633
commit
7ea5fa3dd4
@ -478,16 +478,6 @@
|
||||
"QuantType",
|
||||
"wrap_cpp_module"
|
||||
],
|
||||
"torch.ao.quantization.utils": [
|
||||
"Any",
|
||||
"Callable",
|
||||
"Pattern",
|
||||
"QuantType",
|
||||
"Tuple",
|
||||
"Union",
|
||||
"is_parametrized",
|
||||
"quant_type_to_str"
|
||||
],
|
||||
"torch.ao.sparsity.experimental.pruner.base_pruner": [
|
||||
"ActivationReconstruction",
|
||||
"BaseSparsifier",
|
||||
|
||||
128
test/quantization/core/test_utils.py
Normal file
128
test/quantization/core/test_utils.py
Normal file
@ -0,0 +1,128 @@
|
||||
# Owner(s): ["oncall: quantization"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.ao.quantization.utils import get_fqn_to_example_inputs
|
||||
|
||||
|
||||
class TestUtils(TestCase):
|
||||
def _test_get_fqn_to_example_inputs(self, M, example_inputs, expected_fqn_to_dim):
|
||||
m = M().eval()
|
||||
fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
|
||||
for fqn, expected_dims in expected_fqn_to_dim.items():
|
||||
assert fqn in expected_fqn_to_dim
|
||||
example_inputs = fqn_to_example_inputs[fqn]
|
||||
for example_input, expected_dim in zip(example_inputs, expected_dims):
|
||||
assert example_input.dim() == expected_dim
|
||||
|
||||
def test_get_fqn_to_example_inputs_simple(self):
|
||||
class Sub(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(5, 5)
|
||||
self.linear2 = torch.nn.Linear(5, 5)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(5, 5)
|
||||
self.linear2 = torch.nn.Linear(5, 5)
|
||||
self.sub = Sub()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
x = self.sub(x)
|
||||
return x
|
||||
|
||||
expected_fqn_to_dim = {
|
||||
"": (2,),
|
||||
"linear1": (2,),
|
||||
"linear2": (2,),
|
||||
"sub": (2,),
|
||||
"sub.linear1": (2,),
|
||||
"sub.linear2": (2,)
|
||||
}
|
||||
example_inputs = (torch.rand(1, 5),)
|
||||
self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)
|
||||
|
||||
def test_get_fqn_to_example_inputs_default_kwargs(self):
|
||||
""" Test that we can get example inputs for functions with default keyword arguments
|
||||
"""
|
||||
class Sub(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(5, 5)
|
||||
self.linear2 = torch.nn.Linear(5, 5)
|
||||
|
||||
def forward(self, x, key1=torch.rand(1), key2=torch.rand(1)):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(5, 5)
|
||||
self.linear2 = torch.nn.Linear(5, 5)
|
||||
self.sub = Sub()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
# only override `key2`, `key1` will use default
|
||||
x = self.sub(x, key2=torch.rand(1, 2))
|
||||
return x
|
||||
|
||||
expected_fqn_to_dim = {
|
||||
"": (2,),
|
||||
"linear1": (2,),
|
||||
"linear2": (2,),
|
||||
# second arg is `key1`, which is using default argument
|
||||
# third arg is `key2`, override by callsite
|
||||
"sub": (2, 1, 2),
|
||||
"sub.linear1": (2,),
|
||||
"sub.linear2": (2,)
|
||||
}
|
||||
example_inputs = (torch.rand(1, 5),)
|
||||
self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)
|
||||
|
||||
def test_get_fqn_to_example_inputs_complex_args(self):
|
||||
""" Test that we can record complex example inputs such as lists and dicts
|
||||
"""
|
||||
class Sub(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(5, 5)
|
||||
self.linear2 = torch.nn.Linear(5, 5)
|
||||
|
||||
def forward(self, x, list_arg, dict_arg):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(5, 5)
|
||||
self.linear2 = torch.nn.Linear(5, 5)
|
||||
self.sub = Sub()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
x = self.sub(x, [x], {"3": x})
|
||||
return x
|
||||
|
||||
example_inputs = (torch.rand(1, 5),)
|
||||
m = M().eval()
|
||||
fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
|
||||
assert "sub" in fqn_to_example_inputs
|
||||
assert isinstance(fqn_to_example_inputs["sub"][1], list)
|
||||
assert isinstance(fqn_to_example_inputs["sub"][2], dict) and \
|
||||
"3" in fqn_to_example_inputs["sub"][2]
|
||||
@ -36,6 +36,7 @@ from quantization.core.test_workflow_module import TestRecordHistogramObserver
|
||||
from quantization.core.test_workflow_module import TestHistogramObserver # noqa: F401
|
||||
from quantization.core.test_workflow_module import TestDistributed # noqa: F401
|
||||
from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule # noqa: F401
|
||||
from quantization.core.test_utils import TestUtils # noqa: F401
|
||||
|
||||
|
||||
# Eager Mode Workflow. Tests for the functionality of APIs and different features implemented
|
||||
|
||||
@ -5,13 +5,18 @@ import warnings
|
||||
import functools
|
||||
import torch
|
||||
from torch.ao.quantization.quant_type import QuantType, quant_type_to_str
|
||||
from typing import Tuple, Any, Union, Callable
|
||||
from typing import Tuple, Any, Union, Callable, Dict, Optional
|
||||
import typing
|
||||
from torch.nn.utils.parametrize import is_parametrized
|
||||
from collections import OrderedDict
|
||||
from inspect import signature
|
||||
from inspect import getfullargspec
|
||||
|
||||
# Type for fusion patterns, it can be more complicated than the following actually,
|
||||
# see pattern.md for docs
|
||||
# TODO: not sure if typing supports recursive data types
|
||||
Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any]
|
||||
Pattern.__module__ = "torch.ao.quantization.utils"
|
||||
|
||||
# TODO: maybe rename this to MatchInputNode
|
||||
class MatchAllNode:
|
||||
@ -95,6 +100,7 @@ method_list = {
|
||||
'view',
|
||||
}
|
||||
|
||||
# TODO: not used now, remove
|
||||
def check_node(node, modules):
|
||||
# TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
|
||||
is_call_function = node.op == "call_function" and node.target in func_list
|
||||
@ -370,3 +376,181 @@ def has_no_children_ignoring_parametrizations(module):
|
||||
return len(module._modules) == 1 and 'parametrizations' in module._modules
|
||||
else:
|
||||
return False
|
||||
|
||||
def _get_path_of_module(root: torch.nn.Module, submodule: torch.nn.Module) -> Optional[str]:
    """ Look up the fully qualified name of `submodule` inside `root`

    The modules yielded by `root.named_modules()` are compared with
    `submodule` by identity; the name of the first match is returned, or
    None when there is no match.

    Example::

        >> class M(torch.nn.Module):
               def __init__(self):
                   super().__init__()
                   self.linear = torch.nn.Linear(5, 5)
               def forward(self, x):
                   return self.linear(x)

        >> m = M()
        >> l = m.linear
        >> _get_path_of_module(m, l)
        "linear"
    """
    matching_names = (
        name for name, module in root.named_modules() if module is submodule
    )
    return next(matching_names, None)
|
||||
|
||||
def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]:
    """ Keep only the entries of `loc` whose key is a parameter name of `f`

    Args:
        f: callable whose signature decides which keys are kept
        loc: mapping from name to value, e.g. the keyword arguments of a call

    Returns:
        A new dict with the matching entries of `loc`, in the order they
        appear in `loc`

    Example::

        >> def f(self, a, b=9):
               pass
        >> loc = {"a": 6, "c": 7}
        >> _get_signature_locals(f, loc)
        {"a": 6}
    """
    if not loc:
        # nothing to filter, so there is no need to introspect `f` at all
        return {}
    # Introspect the signature once rather than once per entry of `loc`.
    param_names = signature(f).parameters
    return {k: v for k, v in loc.items() if k in param_names}
|
||||
|
||||
def _get_default_kwargs(f: Callable) -> typing.OrderedDict[str, Any]:
    """ Collect the default value of every parameter of `f` that has one

    The result follows the parameter order of the signature. A variable
    positional parameter (`*args`) is reported as an empty tuple and a
    variable keyword parameter (`**kwargs`) as an empty dict; parameters
    without a default are left out.

    Example::

        >> def f(self, a, b=9):
               pass
        >> _get_default_kwargs(f)
        {"b": 9}
    """
    defaults = OrderedDict()
    for param_name, param in signature(f).parameters.items():
        if param.default is not param.empty:
            defaults[param_name] = param.default
            continue
        if param.kind is param.VAR_POSITIONAL:
            defaults[param_name] = ()
        elif param.kind is param.VAR_KEYWORD:
            defaults[param_name] = {}
    return defaults
|
||||
|
||||
def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> typing.OrderedDict[str, Any]:
    """ Merge the default keyword arguments of `func` with the values in `loc`

    Starts from the defaults found in the signature of `func` and replaces
    each one for which `loc` provides a value. Entries of `loc` that do not
    correspond to a parameter with a default are ignored.

    Example::

        >> def f(self, key1=3, key2=3):
               pass
        >> loc = {"key2": 6}
        >> _normalize_kwargs(f, loc)
        {"key1": 3, "key2": 6}
    """
    normalized = _get_default_kwargs(func).copy()
    overrides = {
        name: value
        for name, value in _get_signature_locals(func, loc).items()
        if name in normalized
    }
    # updating existing keys keeps the signature order of the defaults
    normalized.update(overrides)
    return normalized
|
||||
|
||||
def _get_num_pos_args(f: Callable) -> int:
    """ Count the named positional parameters of `f`

    The count is the length of `getfullargspec(f).args`, so `*args`,
    keyword-only parameters and `**kwargs` are not included.

    Example::

        >> def f(self, key1=3, key2=3):
               pass
        >> _get_num_pos_args(f)
        3
    """
    argspec = getfullargspec(f)
    return len(argspec.args)
|
||||
|
||||
def get_fqn_to_example_inputs(
    model: torch.nn.Module,
    example_inputs: Tuple[Any, ...]
) -> Dict[str, Tuple[Any, ...]]:
    """ Given a model and its example inputs, return a dictionary from
    fully qualified name of submodules to example_inputs for that submodule,
    e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
          "sub.linear1": (tensor4,), ...}

    Used to make quantizing submodules easier now that FX Graph Mode Quantization requires
    example inputs.

    Also works for keyword arguments with default values, we would flatten keyword
    arguments as positional arguments and fill in the missing keyword args with default
    values, e.g. if we have a forward function:
    def forward(self, x, key1=3, key2=3):
        ...

    and we call it with self.submodule(x, key2=6)
    we'll get example_inputs: (x, 3, 6)

    user can also override `key1` with positional arguments as well:
    for self.submodule(x, 5, key2=6)
    we'll get: (x, 5, 6)

    variable positional arguments and variable positional keyword arguments in forward
    function are not supported currently, so please make sure no submodules is using
    them.

    The model is run once with `example_inputs`. While it runs,
    `torch.nn.Module.__call__` and `model.__class__` are temporarily replaced;
    both are restored afterwards, also when the forward pass raises.

    Args:
        model: root module to run
        example_inputs: positional inputs passed to `model`

    Returns:
        Dictionary from the fully qualified name of each module that was called
        to the tuple of inputs recorded for its most recent call
    """
    root = model
    fqn_to_example_inputs: Dict[str, Tuple[Any, ...]] = {}

    class InterceptionModule(type(model)):  # type: ignore[misc]
        def __call__(self, *args, **kwargs):
            orig_module_call = torch.nn.Module.__call__

            def _patched_module_call(self, *args, **kwargs):
                submodule_example_inputs = list(args)
                normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
                # Parameters that were already filled by positional arguments
                # must not be emitted a second time from the defaults, so drop
                # exactly those names (the signature of the bound `forward`
                # does not list `self`).
                positional_names = [
                    param.name
                    for param in signature(self.forward).parameters.values()
                    if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
                ]
                for name in positional_names[:len(submodule_example_inputs)]:
                    normalized_kwargs.pop(name, None)
                submodule_example_inputs.extend(normalized_kwargs.values())
                submodule_example_inputs_tuple = tuple(submodule_example_inputs)
                fqn = _get_path_of_module(root, self)
                if fqn is not None:
                    fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
                return orig_module_call(self, *args, **kwargs)

            torch.nn.Module.__call__ = _patched_module_call
            try:
                return super().__call__(*args, **kwargs)
            finally:
                # always undo the global patch, even if forward raised
                torch.nn.Module.__call__ = orig_module_call

    original_class = model.__class__
    model.__class__ = InterceptionModule
    try:
        model(*example_inputs)
    finally:
        # always give the model its original class back
        model.__class__ = original_class

    return fqn_to_example_inputs
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Pattern",
|
||||
"MatchAllNode",
|
||||
"check_node",
|
||||
"get_combined_dict",
|
||||
"is_per_tensor",
|
||||
"is_per_channel",
|
||||
"getattr_from_fqn",
|
||||
"get_qparam_dict",
|
||||
"get_swapped_custom_module_class",
|
||||
"activation_dtype",
|
||||
"weight_dtype",
|
||||
"activation_is_statically_quantized",
|
||||
"activation_is_dynamically_quantized",
|
||||
"activation_is_int8_quantized",
|
||||
"activation_is_int32_quantized",
|
||||
"weight_is_quantized",
|
||||
"weight_is_statically_quantized",
|
||||
"op_is_int8_dynamically_quantized",
|
||||
"get_qconfig_dtypes",
|
||||
"get_quant_type",
|
||||
"check_min_max_valid",
|
||||
"calculate_qmin_qmax",
|
||||
"has_no_children_ignoring_parametrizations",
|
||||
"get_fqn_to_example_inputs",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user