Files
pytorch/torch/distributed/fsdp/_trace_utils.py
Nikita Shulga 634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00

238 lines
10 KiB
Python

import functools
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple

import torch
import torch.nn as nn
@dataclass
class TracingConfig:
    """
    This represents a symbolic tracing configuration.

    Args:
        tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to
            use for symbolic tracing. The default value is a native
            :class:`torch.fx.Tracer` constructed with default arguments. A
            fresh default tracer is constructed for every ``TracingConfig``
            instance, so separately created configs never share one tracer
            object. However, the user may want to pass a different value
            such as the ``HFTracer`` for models in the HuggingFace
            Transformers_ library.
            .. _Transformers: https://huggingface.co/docs/transformers/index
        concrete_args (Optional[Dict[str, Any]]): Concrete arguments that
            should not be treated as ``torch.fx.Proxy`` when tracing the
            module ``forward()``. Passing ``concrete_args`` allows partially
            specializing the forward, e.g. to remove control flow or data
            structures. This ``concrete_args`` here is the same argument used
            in :meth:`~torch.fx.Tracer.trace`.
    """

    # ``default_factory`` instead of a plain default: a plain
    # ``torch.fx.Tracer()`` default is evaluated once at class-definition time
    # and that single mutable object would be shared by every config (and
    # monkeypatched in place by ``_ExecOrderTracer.patch_tracer``).
    tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
    concrete_args: Optional[Dict[str, Any]] = None
class _ParamUsageInfo(NamedTuple):
"""
This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record
execution information. The ``dict`` maps modules to a list of these
``_ParamUsageInfo`` instances, where each instance represents a group of
parameters used together.
Specifically, for each module key in the ``dict``, each instance of this
class represents either:
(1) the module and some sublist of its ``named_parameters()`` used
together in execution (see ``_patched_create_proxy()``), or
(2) a submodule and all of ``submodule.named_parameters()`` (see
``_patched_call_module()``).
Type (1) corresponds to directly using parameters in ops without calling
``forward()``, and type (2) corresponds to calling ``forward()``. The
mapped-to lists in the ``dict`` follow the execution order.
"""
module: nn.Module
named_params: List[Tuple[str, nn.Parameter]]
class _ExecutionInfo:
"""
This represents the execution order information from the forward pass.
Attributes:
curr_module (nn.Module): Current module being traced.
module_forward_order (List[nn.Module]): The modules in (pre-)forward
order, i.e. the order in which their ``forward()`` methods are
called. Each call to a module's ``forward()`` corresponds to one
element in the list.
module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
Maps a module to a list of module execution infos. See
:class:`_ParamUsageInfo` for details.
param_forward_order (List[nn.Parameter]): The parameters in forward
execution order, where only a parameter's first participation is
included.
visited_params (Set[nn.Parameter]): The parameters visited so far
during the trace. This is only used during tracing for fast
membership check. Invariant: The parameters in
``param_forward_order`` are exactly those in ``visited_params``.
"""
def __init__(self, root_module: nn.Module) -> None:
self.curr_module: nn.Module = root_module
self.module_forward_order: List[nn.Module] = [root_module]
self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {
root_module: []
}
self.param_forward_order: List[nn.Parameter] = []
self.visited_params: Set[nn.Parameter] = set()
class _ExecOrderTracer:
    """
    Records execution-order information while a :class:`torch.fx.Tracer`
    symbolically traces a module.

    Use :meth:`patch_tracer` as a context manager around the trace; the
    collected information is available afterwards in ``self.exec_info``.
    """

    def __init__(self) -> None:
        # Populated by ``patch_tracer()``; ``None`` until it is first used.
        self.exec_info: Optional[_ExecutionInfo] = None

    @contextmanager
    def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
        """
        Temporarily overrides ``tracer.call_module`` and ``tracer.create_proxy``
        so that tracing ``root_module`` records execution information into a
        fresh :class:`_ExecutionInfo`, stored as ``self.exec_info``.

        Args:
            tracer (torch.fx.Tracer): Tracer whose ``call_module`` and
                ``create_proxy`` are patched for the duration of the
                ``with`` block.
            root_module (nn.Module): Root module to be traced. Its
                ``named_parameters()`` determine which FQNs are recognized as
                parameters.

        The original methods are restored on exit, including when the body
        raises. ``self.exec_info`` is left populated after exit so that the
        caller can read the results.
        """
        self.exec_info = _ExecutionInfo(root_module)
        # Compute everything that can fail before touching the tracer, so a
        # failure here cannot leave only one of the two methods patched.
        fqn_to_param = dict(root_module.named_parameters())
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
        tracer.call_module = functools.partial(
            self._patched_call_module, orig_call_module, self.exec_info
        )
        tracer.create_proxy = functools.partial(
            self._patched_create_proxy,
            orig_create_proxy,
            self.exec_info,
            fqn_to_param,
        )
        try:
            yield
        finally:
            tracer.call_module = orig_call_module
            tracer.create_proxy = orig_create_proxy

    @staticmethod
    def _iter_proxies(obj: Any) -> Iterator[torch.fx.Proxy]:
        """
        Yields every ``torch.fx.Proxy`` in ``obj`` in order, descending into
        nested tuples, lists, and dict values; other objects yield nothing.
        """
        if isinstance(obj, torch.fx.Proxy):
            yield obj
        elif isinstance(obj, (tuple, list)):
            for item in obj:
                yield from _ExecOrderTracer._iter_proxies(item)
        elif isinstance(obj, dict):
            for item in obj.values():
                yield from _ExecOrderTracer._iter_proxies(item)

    @staticmethod
    def _record_first_use(exec_info: _ExecutionInfo, param: nn.Parameter) -> None:
        """
        Appends ``param`` to ``exec_info.param_forward_order`` the first time
        it is seen, keeping ``visited_params`` in sync with that list.
        """
        if param not in exec_info.visited_params:
            exec_info.visited_params.add(param)
            exec_info.param_forward_order.append(param)

    def _patched_call_module(
        self,
        call_module: Callable,
        exec_info: _ExecutionInfo,
        # Below are the expected arguments to `call_module()`
        module: nn.Module,
        forward: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Overrides ``call_module`` to save execution information to
        ``exec_info``. Note that ``call_module`` is called during symbolic
        tracing for each non-root module.

        Args:
            call_module (Callable): Original ``call_module`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            module (nn.Module): Module corresponding to this ``call_module``.
            forward (Callable): ``forward()`` method of ``module`` to be called
                for this ``call_module``.
            args (Tuple[Any, ...]): Positional arguments for ``forward``.
            kwargs (Dict[str, Any]): Keyword arguments for ``forward``.

        Returns:
            Same return value as ``call_module``.
        """
        exec_info.module_forward_order.append(module)
        named_params = list(module.named_parameters())
        parent_module = exec_info.curr_module
        if named_params:
            assert (
                parent_module in exec_info.module_to_param_usage_infos
            ), "The current module should have already been processed by a patched `call_module`"
            # Attribute the whole submodule call to the calling module.
            exec_info.module_to_param_usage_infos[parent_module].append(
                _ParamUsageInfo(module, named_params)
            )
        exec_info.curr_module = module
        # Note: a module called more than once gets its usage list reset here.
        exec_info.module_to_param_usage_infos[module] = []
        try:
            return call_module(module, forward, args, kwargs)
        finally:
            # Restore the caller as the current module even if tracing the
            # submodule raises, so ``exec_info`` stays consistent.
            exec_info.curr_module = parent_module

    def _patched_create_proxy(
        self,
        create_proxy: Callable,
        exec_info: _ExecutionInfo,
        fqn_to_param: Dict[str, nn.Parameter],
        # Below are the expected arguments to `create_proxy()`
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
        proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
    ) -> torch.fx.Proxy:
        """
        Overrides ``create_proxy`` to save execution information to
        ``exec_info``. Note that ``create_proxy`` is called during symbolic
        tracing for each leaf function/method/module.

        For ``call_function``/``call_method`` nodes, a parameter use is any
        ``get_attr`` proxy whose target is a key of ``fqn_to_param``. Such a
        proxy is found anywhere in ``args`` or ``kwargs``, including inside
        nested tuples, lists, and dict values. For ``call_module`` nodes (leaf
        modules), all parameters of the current module count as used.

        Args:
            create_proxy (Callable): Original ``create_proxy`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
                root module's ``named_parameters()`` with FQN as key and
                parameter as value.
            kind (str): Kind of the target method ('call_function',
                'call_method', 'get_attr', 'call_module', 'placeholder', or
                'output'). See :class:`torch.fx.Graph` for details. This is
                passed to ``create_proxy``.
            target (torch.fx.node.Target): Contains the string name of the
                function/method/module. This is passed to ``create_proxy``.
            args (Tuple[Any, ...]): Positional arguments for the function/
                method/module. This is passed to ``create_proxy``.
            kwargs (Dict[str, Any]): Keyword arguments for the function/method/
                module. This is passed to ``create_proxy``.
            name (Optional[str]): An optional string name for the ``Node``
                created in ``create_proxy``. This is passed to
                ``create_proxy``.
            type_expr (Optional[Any]): An optional type annotation representing
                the Python type that the output of the node has. This is passed
                to ``create_proxy``.
            proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]):
                An alternative proxy constructor used in ``create_proxy``. This
                is passed to ``create_proxy``.

        Returns:
            torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
        """
        proxy = create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )
        curr_module = exec_info.curr_module
        if kind in ("call_function", "call_method"):
            named_params: List[Tuple[str, nn.Parameter]] = []
            # Parameters may be passed positionally, by keyword, or inside
            # containers (e.g. ``torch.cat([w1, w2])``), so scan all of them.
            for arg_proxy in self._iter_proxies((args, kwargs)):
                node = arg_proxy.node
                # Only ``get_attr`` nodes refer to module attributes; checking
                # the op keeps e.g. a ``placeholder`` (forward input) that is
                # named like a root-level parameter from matching it.
                if node.op != "get_attr" or node.target not in fqn_to_param:
                    continue
                param = fqn_to_param[node.target]
                named_params.append((node.target, param))
                self._record_first_use(exec_info, param)
            if named_params:
                exec_info.module_to_param_usage_infos[curr_module].append(
                    _ParamUsageInfo(curr_module, named_params)
                )
        elif kind == "call_module":
            # ``curr_module`` is the leaf module being called, because the
            # patched ``call_module`` switched to it before delegating.
            module_named_params = list(curr_module.named_parameters())
            if module_named_params:
                exec_info.module_to_param_usage_infos[curr_module].append(
                    _ParamUsageInfo(curr_module, module_named_params)
                )
            for _, param in module_named_params:
                self._record_first_use(exec_info, param)
        return proxy