Enable mypy allow redefinition (#102046)
Related #101528
I tried to enable this in another PR, but it uncovered a bunch of type errors: https://github.com/pytorch/pytorch/actions/runs/4999748262/jobs/8956555243?pr=101528#step:10:1305
The goal of this PR is to fix these errors.
---
This PR enables [allow_redefinition = True](https://mypy.readthedocs.io/en/stable/config_file.html#confval-allow_redefinition) in `mypy.ini`, which allows for a common pattern:
> Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition.
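For example (a minimal sketch adapted from the mypy documentation, not code from this PR), the reassignment below is accepted under `allow_redefinition` because it happens at the same block and nesting level where the variable was originally defined:

```python
from typing import List


def process(items: List[str]) -> None:
    # `items` starts as List[str]; after this reassignment mypy treats it
    # as List[List[str]]. Without allow_redefinition, mypy reports an
    # incompatible-type assignment here.
    items = [item.split() for item in items]
    print(items)
```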
`allow_redefinition` makes mypy more flexible by permitting reassignment to an existing variable with a different type. For instance (from the linked PR):
4a1e9230ba/torch/nn/parallel/data_parallel.py (L213)
A `Sequence[Union[int, torch.device]]` is narrowed to `Sequence[int]` through reassignment to the same variable.
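A minimal sketch of that narrowing pattern (hypothetical names, standing in for the `torch.device` case in `data_parallel.py`):

```python
from typing import List, Sequence, Union


def normalize_ids(ids: Sequence[Union[int, str]]) -> List[int]:
    # The reassignment narrows `ids` from Sequence[Union[int, str]] to
    # List[int]; with allow_redefinition mypy tracks the new type instead
    # of reporting an incompatible assignment to the same name.
    ids = [int(i) for i in ids]
    return ids
```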
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102046
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: bf059e3925
Commit: 29da75cc55
@@ -2,6 +2,7 @@
 plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
 
 cache_dir = .mypy_cache/nofollow
+allow_redefinition = True
 warn_unused_configs = True
 warn_redundant_casts = True
 show_error_codes = True
@@ -10,6 +10,7 @@ python_version = 3.8
 plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
 
 cache_dir = .mypy_cache/strict
+allow_redefinition = True
 strict_optional = True
 show_error_codes = True
 show_column_numbers = True
mypy.ini
@@ -5,6 +5,7 @@
 plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
 
 cache_dir = .mypy_cache/normal
+allow_redefinition = True
 warn_unused_configs = True
 warn_redundant_casts = True
 show_error_codes = True
@@ -29,7 +29,7 @@ class TestFuture(TestCase):
         f.wait()
 
         # Exception should also throw on value
-        f = Future()
+        f = Future[T]()
         f.set_exception(value_error)
         with self.assertRaisesRegex(ValueError, "Intentional"):
             f.value()
@@ -37,7 +37,7 @@ class TestFuture(TestCase):
         def cb(fut):
             fut.value()
 
-        f = Future()
+        f = Future[T]()
        f.set_exception(value_error)
 
         with self.assertRaisesRegex(RuntimeError, "Got the following error"):
@@ -90,7 +90,7 @@ pw_cast_for_int_to_real = partial(
 
 
 # This expands x until x.dim() == dim. Might be useful as an operator
-def _unsqueeze_to_dim(x: Tensor, dim: int):
+def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor:
     for _ in range(dim - x.dim()):
         x = x.unsqueeze(-1)
     return x
@@ -6,6 +6,7 @@ from types import MappingProxyType
 from warnings import warn
 
 import torch
+import torch.fx
 import torch.overrides
 from torch._prims_common import (
     _torch_dtype_to_nvfuser_dtype_map,
@@ -181,24 +182,24 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
     class FusionInterpreter(torch.fx.Interpreter):
         def run_node(self, node):
             # Squeeze requires original shape of args[0]
-            if node.target in [
+            if node.target in (
                 torch.ops.nvprims.squeeze,
                 torch.ops.nvprims.squeeze.default,
-            ]:
+            ):
                 original_shape = list(node.args[0].meta["tensor_meta"].shape)
                 assert len(node.args) == 2
                 args, kwargs = self.fetch_args_kwargs_from_env(node)
-                args = [args[0], original_shape, args[1]]
+                args = args[:1] + (original_shape,) + args[1:]
                 return self.call_function(node.target, args, node.kwargs)
 
-            if node.target in [
+            if node.target in (
                 torch.ops.nvprims.native_batch_norm,
                 torch.ops.nvprims.native_batch_norm.default,
-            ]:
+            ):
                 args, kwargs = self.fetch_args_kwargs_from_env(node)
                 assert len(args) == 8
                 training = args[5]
-                args6_end = tuple(map(_to_nvfuser_constant, args[6:]))
+                args6_end = tuple(_to_nvfuser_constant(arg) for arg in args[6:])
                 args = args[:5] + (training,) + args6_end
                 return node.target.impl_nvfuser(fd, *args, **kwargs)
 
@@ -209,7 +210,7 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
             if target == operator.getitem:
                 assert isinstance(args[0], tuple)
                 return target(*args, **kwargs)
-            args = tuple(map(_to_nvfuser_constant, args))
+            args = tuple(_to_nvfuser_constant(arg) for arg in args)
             target = target.impl_nvfuser
             args = (fd,) + args
             return target(*args, **kwargs)
@@ -239,7 +240,9 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
         return arg
 
     # Transforms graph to call nvfuser lowerings
-    nv_args = tuple(map(templates_to_nvfuser_inputs, nv_args_templates))
+    nv_args = tuple(
+        templates_to_nvfuser_inputs(nv_arg) for nv_arg in nv_args_templates
+    )
     out = FusionInterpreter(gm).run(*nv_args)
     flat_out, unflatten_spec = tree_flatten(out)
 
@@ -24,7 +24,7 @@ def prepare_pt2e(
     qconfig_mapping: QConfigMapping,
     example_inputs: Tuple[Any, ...],
     backend_config: BackendConfig,
-):
+) -> GraphModule:
     node_name_to_scope = _get_node_name_to_scope(model)
 
     # TODO: check qconfig_mapping to make sure conv and bn are both configured
@@ -50,7 +50,7 @@ def prepare_pt2e(
 def prepare_pt2e_quantizer(
     model: GraphModule,
     quantizer: Quantizer,
-):
+) -> GraphModule:
     node_name_to_scope = _get_node_name_to_scope(model)
     # TODO: check qconfig_mapping to make sure conv and bn are both configured
     # to be quantized before fusion
@@ -66,7 +66,7 @@ def prepare_pt2e_quantizer(
 def prepare_qat_pt2e_quantizer(
     model: GraphModule,
     quantizer: Quantizer,
-):
+) -> GraphModule:
     node_name_to_scope = _get_node_name_to_scope(model)
     quantizer.annotate(model)
     quantizer.validate(model)
@@ -83,7 +83,7 @@ def prepare_qat_pt2e_quantizer(
 
 def convert_pt2e(
     model: GraphModule
-):
-    model = _convert_to_reference_decomposed_fx(model)  # type: ignore[assignment]
+) -> GraphModule:
+    model = _convert_to_reference_decomposed_fx(model)
     model = _fold_conv_bn_qat(model)
     return model
@@ -431,7 +431,7 @@ def _replace_observer_with_quantize_dequantize_node(
 # TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
 # after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
 # after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
-def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph):
+def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph) -> None:
     call_custom_module_node = node.args[0]
     assert isinstance(call_custom_module_node, Node), \
         f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
@@ -479,7 +479,7 @@ def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig)
                     # run the weight observer
                     weight_observer_module()
 
-def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
+def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
     """ If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
     we'll recursively remove the dequantize Node
     """
@@ -502,7 +502,7 @@ def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
 def _get_module_path_and_prefix(
         obs_node: Node,
         node_name_to_scope: Dict[str, Tuple[str, type]],
-        node_name_to_qconfig: Dict[str, QConfigAny]):
+        node_name_to_qconfig: Dict[str, QConfigAny]) -> Tuple[str, str]:
     """ Given and observer node, get the `Scope` or the fully qualified name for
     the submodule containing the observed node, also return a prefix of "_input"
     when the observed node is an input of a F.linear op, and not the output of another
@@ -549,7 +549,7 @@ def _get_module_path_and_prefix(
 
 def _insert_dequantize_node(
         node: Node,
-        graph: Graph):
+        graph: Graph) -> None:
     """ Inserts dequantize node for `node` in `graph`
     """
     with graph.inserting_after(node):
@@ -578,7 +578,7 @@ def convert_standalone_module(
         modules: Dict[str, torch.nn.Module],
         model: torch.fx.GraphModule,
         is_reference: bool,
-        backend_config: Optional[BackendConfig]):
+        backend_config: Optional[BackendConfig]) -> None:
     """ Converts a observed standalone module to a quantized standalone module by calling
     the fx convert api, currently using the same `is_reference` flag as parent, but we may
     changing this behavior in the future (e.g. separating quantization and lowering for
@@ -641,7 +641,7 @@ def convert_weighted_module(
         observed_node_names: Set[str],
         node_name_to_qconfig: Dict[str, QConfigAny],
         backend_config: BackendConfig,
-        is_decomposed: bool = False):
+        is_decomposed: bool = False) -> None:
     """ Convert a weighted module to reference quantized module in the model
     If the QConfig of a QAT module is not set, the module will still be converted to
     a float module.
@@ -746,7 +746,7 @@ def convert_weighted_module(
     parent_name, name = _parent_name(node.target)
     setattr(modules[parent_name], name, ref_qmodule)
 
-def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph):
+def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph) -> None:
     """
     Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows:
 
@@ -768,7 +768,7 @@ def convert_custom_module(
         graph: Graph,
         modules: Dict[str, torch.nn.Module],
         custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
-        statically_quantized_custom_module_nodes: Set[Node]):
+        statically_quantized_custom_module_nodes: Set[Node]) -> None:
     """ Converts an observed custom module to a quantized custom module based on
     `custom_module_class_mapping`
     For static quantization, we'll also remove the previous `dequantize` node and
@@ -853,7 +853,7 @@ def convert(
         _remove_qconfig_flag: bool = True,
         qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
         backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
-        is_decomposed: bool = False) -> torch.nn.Module:
+        is_decomposed: bool = False) -> GraphModule:
     """
     We will convert an observed model (a module with observer calls) to a reference
     quantized model, the rule is simple:
@@ -1333,7 +1333,7 @@ def insert_observers_for_model(
     # if not, we'll reset the target_dtye_info to use the default (float Tensor)
 
     # reset the counters and set of processed_nodes
-    processed_nodes = set()
+    processed_nodes: Set[Node] = set()
     for node_name, match_res_with_qconfig in node_name_to_match_result_with_qconfig.items():
         last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
         is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
@@ -28,7 +28,9 @@ from .fx.utils import get_skipped_module_name_and_classes
 from .qconfig_mapping import QConfigMapping
 
 def attach_preserved_attrs_to_model(
-        model: Union[GraphModule, torch.nn.Module], preserved_attrs: Dict[str, Any]):
+    model: Union[GraphModule, torch.nn.Module],
+    preserved_attrs: Dict[str, Any],
+) -> None:
     """ Store preserved attributes to the model.meta so that it can be preserved during deepcopy
     """
     model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs)  # type: ignore[operator, index, assignment]
@@ -47,7 +49,7 @@ def _check_is_graph_module(model: torch.nn.Module) -> None:
             + "sure to follow the tutorials."
         )
 
-def _attach_meta_to_node_if_not_exist(model: GraphModule):
+def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None:
     """ Attach meta field to all nodes of the graph if it does not exist,
     meta field is a field stores some meta information about the node, such
     as dtype and shape information for output of the node, this only exists
@@ -503,7 +505,7 @@ def _convert_fx(
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
     is_decomposed: bool = False,
-) -> torch.nn.Module:
+) -> GraphModule:
     """ `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
     """
     if convert_custom_config is None:
@@ -540,7 +542,7 @@ def convert_fx(
     _remove_qconfig: bool = True,
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
-) -> torch.nn.Module:
+) -> GraphModule:
     r""" Convert a calibrated or trained model to a quantized model
 
     Args:
@@ -607,7 +609,7 @@ def convert_to_reference_fx(
     _remove_qconfig: bool = True,
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
-) -> torch.nn.Module:
+) -> GraphModule:
     r""" Convert a calibrated or trained model to a reference quantized model,
     see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
     reference quantized model is a standard representation of a quantized model provided
@@ -656,7 +658,7 @@ def _convert_to_reference_decomposed_fx(
     _remove_qconfig: bool = True,
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
-) -> torch.nn.Module:
+) -> GraphModule:
     r""" Convert a calibrated or trained model to a reference quantized model, with
     decomposed representation for quantized Tensor
     see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
@@ -708,7 +710,7 @@ def _convert_standalone_module_fx(
     graph_module: GraphModule,
     is_reference: bool = False,
     convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
-) -> torch.nn.Module:
+) -> GraphModule:
     r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
     and convert it to a quantized model
 
@@ -124,7 +124,7 @@ class _lazy_property_and_property(lazy_property, property):
         property.__init__(self, wrapped)
 
 
-def tril_matrix_to_vec(mat, diag=0):
+def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor:
     r"""
     Convert a `D x D` matrix or a batch of matrices into a (batched) vector
     which comprises of lower triangular elements from the matrix in row order.
@@ -138,7 +138,7 @@ def tril_matrix_to_vec(mat, diag=0):
     return vec
 
 
-def vec_to_tril_matrix(vec, diag=0):
+def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor:
     r"""
     Convert a vector or a batch of vectors into a batched `D x D`
     lower triangular matrix containing elements from the vector in row order.
@@ -149,7 +149,7 @@ def vec_to_tril_matrix(vec, diag=0):
     if not torch._C._get_tracing_state() and (round(n) - n > eps):
         raise ValueError(f'The size of last dimension is {vec.shape[-1]} which cannot be expressed as ' +
                          'the lower triangular part of a square D x D matrix.')
-    n = torch.round(n).long() if isinstance(n, torch.Tensor) else round(n)
+    n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
     mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
     arange = torch.arange(n, device=vec.device)
     tril_mask = arange < arange.view(-1, 1) + (diag + 1)
@@ -194,12 +194,12 @@ class CapabilityBasedPartitioner:
         for id, partition in partitions_by_id.items():
             compute_node_count = 0
             for node in partition.nodes:
-                if node.op == "call_function" and \
-                    _get_qualified_name(node.target) not in non_compute_ops:  # type: ignore[arg-type]
-                    compute_node_count += 1
-                if node.op == "call_function" and \
-                    _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
-                    compute_node_count += 1
+                if node.op == "call_function":
+                    assert callable(node.target)
+                    if _get_qualified_name(node.target) not in non_compute_ops:
+                        compute_node_count += 1
+                    if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
+                        compute_node_count += 1
             if compute_node_count <= 1:
                 partitions_to_remove.append(id)
         for id in partitions_to_remove:
@@ -1462,14 +1462,14 @@ def embedding_bag(
 def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type):
     unique_indices = g.op("Unique", indices)
     partial_weight = g.op("Gather", weight, unique_indices)
-    norm_type = int(norm_type)
-    if norm_type == 1:
+    norm_i = int(norm_type)
+    if norm_i == 1:
         norm_type = "ReduceL1"
-    elif norm_type == 2:
+    elif norm_i == 2:
         norm_type = "ReduceL2"
     else:
         raise errors.SymbolicValueError(
-            f"Unsupported: ONNX export of embedding_renorm with norm: {norm_type}. "
+            f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. "
             "Only 1. and 2. are supported.",
             weight,
         )
@@ -2025,8 +2025,9 @@ def _write_ninja_file_to_build_library(path,
             cuda_flags += extra_cuda_cflags
             if not any(flag.startswith('-std=') for flag in cuda_flags):
                 cuda_flags.append('-std=c++17')
-            if os.getenv("CC") is not None:
-                cuda_flags = ['-ccbin', os.getenv("CC")] + cuda_flags
+            cc_env = os.getenv("CC")
+            if cc_env is not None:
+                cuda_flags = ['-ccbin', cc_env] + cuda_flags
     else:
         cuda_flags = None
 