[reland][quant] Add utility function get_fqn_to_example_inputs

Summary:
After https://github.com/pytorch/pytorch/pull/77608 `example_inputs` is required input for `prepare_fx` and `prepare_qat_fx`.
This makes quantizing submodules harder, so we added this utility function to get a dictionary from fqn to submodule example_inputs

Example Call:

```
example_inputs = (tensor0,)
get_fqn_to_example_inputs(m, example_inputs)
```

Example output:
```
{
   "linear1": (tensor1,),
   "linear2": (tensor2,),
   "sub": (tensor3,),
   "sub.linear1": (tensor4,),
   ...
}
```

Test Plan:
python test/test_quantization.py TestUtils

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78286

Approved by: https://github.com/dzdang
This commit is contained in:
Jerry Zhang
2022-05-25 13:28:58 -07:00
committed by PyTorch MergeBot
parent 56c23f5633
commit 7ea5fa3dd4
4 changed files with 314 additions and 11 deletions

View File

@ -478,16 +478,6 @@
"QuantType",
"wrap_cpp_module"
],
"torch.ao.quantization.utils": [
"Any",
"Callable",
"Pattern",
"QuantType",
"Tuple",
"Union",
"is_parametrized",
"quant_type_to_str"
],
"torch.ao.sparsity.experimental.pruner.base_pruner": [
"ActivationReconstruction",
"BaseSparsifier",

View File

@ -0,0 +1,128 @@
# Owner(s): ["oncall: quantization"]
import torch
from torch.testing._internal.common_utils import TestCase
from torch.ao.quantization.utils import get_fqn_to_example_inputs
class TestUtils(TestCase):
def _test_get_fqn_to_example_inputs(self, M, example_inputs, expected_fqn_to_dim):
m = M().eval()
fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
for fqn, expected_dims in expected_fqn_to_dim.items():
assert fqn in expected_fqn_to_dim
example_inputs = fqn_to_example_inputs[fqn]
for example_input, expected_dim in zip(example_inputs, expected_dims):
assert example_input.dim() == expected_dim
def test_get_fqn_to_example_inputs_simple(self):
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
self.sub = Sub()
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.sub(x)
return x
expected_fqn_to_dim = {
"": (2,),
"linear1": (2,),
"linear2": (2,),
"sub": (2,),
"sub.linear1": (2,),
"sub.linear2": (2,)
}
example_inputs = (torch.rand(1, 5),)
self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)
def test_get_fqn_to_example_inputs_default_kwargs(self):
""" Test that we can get example inputs for functions with default keyword arguments
"""
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
def forward(self, x, key1=torch.rand(1), key2=torch.rand(1)):
x = self.linear1(x)
x = self.linear2(x)
return x
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
self.sub = Sub()
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
# only override `key2`, `key1` will use default
x = self.sub(x, key2=torch.rand(1, 2))
return x
expected_fqn_to_dim = {
"": (2,),
"linear1": (2,),
"linear2": (2,),
# second arg is `key1`, which is using default argument
# third arg is `key2`, override by callsite
"sub": (2, 1, 2),
"sub.linear1": (2,),
"sub.linear2": (2,)
}
example_inputs = (torch.rand(1, 5),)
self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)
def test_get_fqn_to_example_inputs_complex_args(self):
""" Test that we can record complex example inputs such as lists and dicts
"""
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
def forward(self, x, list_arg, dict_arg):
x = self.linear1(x)
x = self.linear2(x)
return x
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
self.sub = Sub()
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.sub(x, [x], {"3": x})
return x
example_inputs = (torch.rand(1, 5),)
m = M().eval()
fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
assert "sub" in fqn_to_example_inputs
assert isinstance(fqn_to_example_inputs["sub"][1], list)
assert isinstance(fqn_to_example_inputs["sub"][2], dict) and \
"3" in fqn_to_example_inputs["sub"][2]

View File

@ -36,6 +36,7 @@ from quantization.core.test_workflow_module import TestRecordHistogramObserver
from quantization.core.test_workflow_module import TestHistogramObserver # noqa: F401
from quantization.core.test_workflow_module import TestDistributed # noqa: F401
from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule # noqa: F401
from quantization.core.test_utils import TestUtils # noqa: F401
# Eager Mode Workflow. Tests for the functionality of APIs and different features implemented

View File

@ -5,13 +5,18 @@ import warnings
import functools
import torch
from torch.ao.quantization.quant_type import QuantType, quant_type_to_str
from typing import Tuple, Any, Union, Callable
from typing import Tuple, Any, Union, Callable, Dict, Optional
import typing
from torch.nn.utils.parametrize import is_parametrized
from collections import OrderedDict
from inspect import signature
from inspect import getfullargspec
# Type for fusion patterns, it can be more complicated than the following actually,
# see pattern.md for docs
# TODO: not sure if typing supports recursive data types
Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any]
Pattern.__module__ = "torch.ao.quantization.utils"
# TODO: maybe rename this to MatchInputNode
class MatchAllNode:
@ -95,6 +100,7 @@ method_list = {
'view',
}
# TODO: not used now, remove
def check_node(node, modules):
# TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
is_call_function = node.op == "call_function" and node.target in func_list
@ -370,3 +376,181 @@ def has_no_children_ignoring_parametrizations(module):
return len(module._modules) == 1 and 'parametrizations' in module._modules
else:
return False
def _get_path_of_module(root: torch.nn.Module, submodule: torch.nn.Module) -> Optional[str]:
""" Get the path (fully qualified name) of a submodule
Example::
>> class M(torch.nn.Module):
def __init__(self):
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
return self.linear(x)
>> m = M()
>> l = m.linear
>> _get_path_of_module(m, l)
"linear"
"""
for n, p in root.named_modules():
if submodule is p:
return n
return None
def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]:
""" Get local keyword arguments
Example::
>> def f(self, a, b=9):
pass
>> loc = {"a": 6, "c": 7}
>> _get_signature_locals(f, loc)
{"a": 6}
"""
return {k: v for k, v in loc.items() if k in signature(f).parameters}
def _get_default_kwargs(f: Callable) -> typing.OrderedDict[str, Any]:
""" Get all default keyword arguments from function signature
Example::
>> def f(self, a, b=9):
pass
>> _get_default_kwargs(f)
{"b": 9}
"""
kwargs = {}
for name, param in signature(f).parameters.items():
if param.default is not param.empty:
kwargs[name] = param.default
elif param.kind is param.VAR_POSITIONAL:
kwargs[name] = ()
elif param.kind is param.VAR_KEYWORD:
kwargs[name] = {}
return OrderedDict(kwargs)
def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> typing.OrderedDict[str, Any]:
""" Given a function and local function arguments, normalize the keyword
arguments by filling in default arguments from function signature
Example::
>> def f(self, key1=3, key2=3):
pass
>> loc = {"key2": 6}
>> _normalize_kwargs(f, loc)
{"key1": 3, "key2": 6}
"""
default_kwargs = _get_default_kwargs(func)
local_kwargs = _get_signature_locals(func, loc)
normalized_kwargs = default_kwargs.copy()
for attr, val in local_kwargs.items():
if attr in normalized_kwargs:
# override the default keyword arguments
normalized_kwargs[attr] = val
return normalized_kwargs
def _get_num_pos_args(f: Callable) -> int:
""" Get number of positional args for a function
Example::
>> def f(self, key1=3, key2=3):
pass
>> _get_num_pos_args(f)
3
"""
return len(getfullargspec(f).args)
def get_fqn_to_example_inputs(
model: torch.nn.Module,
example_inputs: Tuple[Any, ...]
) -> Dict[str, Tuple[Any, ...]]:
""" Given a model and its example inputs, return a dictionary from
fully qualified name of submodules to example_inputs for that submodule,
e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
"sub.linear1": (tensor4,), ...}
Used to make quantizing submodules easier now that FX Graph Mode Quantization requries
example inputs.
Also works for keyword arguments with default values, we would flatten keyword
arguments as positional arguments and fill in the missing keyword args with default
values, e.g. if we have a forward function:
def forward(self, x, key1=3, key2=3):
...
and we call it with self.submodule(x, key2=6)
we'll get example_inputs: (x, 3, 6)
user can also override `key1` with positional arguments as well:
for self.submodule(x, 5, key2=6)
we'll get: (x, 5, 6)
variable positional arguments and variable positional keyword arguments in forward
function are not supported currently, so please make sure no submodules is using
them.
"""
root = model
fqn_to_example_inputs = {}
class InterceptionModule(type(model)): # type: ignore[misc]
def __call__(self, *args, **kwargs):
orig_module_call = torch.nn.Module.__call__
def _patched_module_call(self, *args, **kwargs):
submodule_example_inputs = list(args).copy()
normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
# minus 1 to skipping counting `self`
num_args = _get_num_pos_args(self.forward) - 1
num_to_pop = num_args - len(submodule_example_inputs)
while num_to_pop and normalized_kwargs:
normalized_kwargs.popitem(last=False)
num_to_pop -= 1
submodule_example_inputs.extend(normalized_kwargs.values())
submodule_example_inputs_tuple = tuple(submodule_example_inputs)
fqn = _get_path_of_module(root, self)
if fqn is not None:
fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
return orig_module_call(self, *args, **kwargs)
torch.nn.Module.__call__ = _patched_module_call
super().__call__(*args, **kwargs)
torch.nn.Module.__call__ = orig_module_call
original_class = model.__class__
model.__class__ = InterceptionModule
model(*example_inputs)
model.__class__ = original_class
return fqn_to_example_inputs
__all__ = [
"Pattern",
"MatchAllNode",
"check_node",
"get_combined_dict",
"is_per_tensor",
"is_per_channel",
"getattr_from_fqn",
"get_qparam_dict",
"get_swapped_custom_module_class",
"activation_dtype",
"weight_dtype",
"activation_is_statically_quantized",
"activation_is_dynamically_quantized",
"activation_is_int8_quantized",
"activation_is_int32_quantized",
"weight_is_quantized",
"weight_is_statically_quantized",
"op_is_int8_dynamically_quantized",
"get_qconfig_dtypes",
"get_quant_type",
"check_min_max_valid",
"calculate_qmin_qmax",
"has_no_children_ignoring_parametrizations",
"get_fqn_to_example_inputs",
]