Enable mypy allow redefinition (#102046)

Related #101528

I tried to enable this in another PR, but it uncovered a bunch of type errors: https://github.com/pytorch/pytorch/actions/runs/4999748262/jobs/8956555243?pr=101528#step:10:1305

The goal of this PR is to fix those errors and enable the setting.

---

This PR enables [allow_redefinition = True](https://mypy.readthedocs.io/en/stable/config_file.html#confval-allow_redefinition) in `mypy.ini`, which allows for a common pattern:

> Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition.

`allow_redefinition` makes mypy more flexible by permitting reassignment to an existing variable with a different type; for instance (from the linked PR):

4a1e9230ba/torch/nn/parallel/data_parallel.py (L213)

A `Sequence[Union[int, torch.device]]` is narrowed to `Sequence[int]` through reassignment to the same variable.
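
For illustration, here is a minimal, self-contained sketch of the pattern this setting permits (not taken from the PR; the function and names are invented):

```python
from typing import Sequence, Union


def as_ints(values: Sequence[Union[int, str]]) -> Sequence[int]:
    # Under the default mypy settings this re-binding is an [assignment] error,
    # because the new value's type (List[int]) differs from the parameter's
    # declared type. With allow_redefinition = True it is accepted, since the
    # redefinition is at the same block and nesting level as the original
    # definition (and the variable is used before being redefined).
    values = [v if isinstance(v, int) else int(v) for v in values]
    return values
```

Without the setting, the usual workarounds are a second variable name or a `# type: ignore[assignment]` comment; the diff below removes one such ignore in `convert_pt2e`.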

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102046
Approved by: https://github.com/ezyang
Matthew Hoffman
2023-05-24 07:05:24 +00:00
committed by PyTorch MergeBot
parent bf059e3925
commit 29da75cc55
14 changed files with 57 additions and 48 deletions

View File

@ -2,6 +2,7 @@
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
cache_dir = .mypy_cache/nofollow
allow_redefinition = True
warn_unused_configs = True
warn_redundant_casts = True
show_error_codes = True

View File

@ -10,6 +10,7 @@ python_version = 3.8
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
cache_dir = .mypy_cache/strict
allow_redefinition = True
strict_optional = True
show_error_codes = True
show_column_numbers = True

View File

@ -5,6 +5,7 @@
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
cache_dir = .mypy_cache/normal
allow_redefinition = True
warn_unused_configs = True
warn_redundant_casts = True
show_error_codes = True

View File

@ -29,7 +29,7 @@ class TestFuture(TestCase):
f.wait()
# Exception should also throw on value
f = Future()
f = Future[T]()
f.set_exception(value_error)
with self.assertRaisesRegex(ValueError, "Intentional"):
f.value()
@ -37,7 +37,7 @@ class TestFuture(TestCase):
def cb(fut):
fut.value()
f = Future()
f = Future[T]()
f.set_exception(value_error)
with self.assertRaisesRegex(RuntimeError, "Got the following error"):

View File

@ -90,7 +90,7 @@ pw_cast_for_int_to_real = partial(
# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int):
def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor:
for _ in range(dim - x.dim()):
x = x.unsqueeze(-1)
return x

View File

@ -6,6 +6,7 @@ from types import MappingProxyType
from warnings import warn
import torch
import torch.fx
import torch.overrides
from torch._prims_common import (
_torch_dtype_to_nvfuser_dtype_map,
@ -181,24 +182,24 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
class FusionInterpreter(torch.fx.Interpreter):
def run_node(self, node):
# Squeeze requires original shape of args[0]
if node.target in [
if node.target in (
torch.ops.nvprims.squeeze,
torch.ops.nvprims.squeeze.default,
]:
):
original_shape = list(node.args[0].meta["tensor_meta"].shape)
assert len(node.args) == 2
args, kwargs = self.fetch_args_kwargs_from_env(node)
args = [args[0], original_shape, args[1]]
args = args[:1] + (original_shape,) + args[1:]
return self.call_function(node.target, args, node.kwargs)
if node.target in [
if node.target in (
torch.ops.nvprims.native_batch_norm,
torch.ops.nvprims.native_batch_norm.default,
]:
):
args, kwargs = self.fetch_args_kwargs_from_env(node)
assert len(args) == 8
training = args[5]
args6_end = tuple(map(_to_nvfuser_constant, args[6:]))
args6_end = tuple(_to_nvfuser_constant(arg) for arg in args[6:])
args = args[:5] + (training,) + args6_end
return node.target.impl_nvfuser(fd, *args, **kwargs)
@ -209,7 +210,7 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
if target == operator.getitem:
assert isinstance(args[0], tuple)
return target(*args, **kwargs)
args = tuple(map(_to_nvfuser_constant, args))
args = tuple(_to_nvfuser_constant(arg) for arg in args)
target = target.impl_nvfuser
args = (fd,) + args
return target(*args, **kwargs)
@ -239,7 +240,9 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
return arg
# Transforms graph to call nvfuser lowerings
nv_args = tuple(map(templates_to_nvfuser_inputs, nv_args_templates))
nv_args = tuple(
templates_to_nvfuser_inputs(nv_arg) for nv_arg in nv_args_templates
)
out = FusionInterpreter(gm).run(*nv_args)
flat_out, unflatten_spec = tree_flatten(out)

View File

@ -24,7 +24,7 @@ def prepare_pt2e(
qconfig_mapping: QConfigMapping,
example_inputs: Tuple[Any, ...],
backend_config: BackendConfig,
):
) -> GraphModule:
node_name_to_scope = _get_node_name_to_scope(model)
# TODO: check qconfig_mapping to make sure conv and bn are both configured
@ -50,7 +50,7 @@ def prepare_pt2e(
def prepare_pt2e_quantizer(
model: GraphModule,
quantizer: Quantizer,
):
) -> GraphModule:
node_name_to_scope = _get_node_name_to_scope(model)
# TODO: check qconfig_mapping to make sure conv and bn are both configured
# to be quantized before fusion
@ -66,7 +66,7 @@ def prepare_pt2e_quantizer(
def prepare_qat_pt2e_quantizer(
model: GraphModule,
quantizer: Quantizer,
):
) -> GraphModule:
node_name_to_scope = _get_node_name_to_scope(model)
quantizer.annotate(model)
quantizer.validate(model)
@ -83,7 +83,7 @@ def prepare_qat_pt2e_quantizer(
def convert_pt2e(
model: GraphModule
):
model = _convert_to_reference_decomposed_fx(model) # type: ignore[assignment]
) -> GraphModule:
model = _convert_to_reference_decomposed_fx(model)
model = _fold_conv_bn_qat(model)
return model

View File

@ -431,7 +431,7 @@ def _replace_observer_with_quantize_dequantize_node(
# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph):
def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph) -> None:
call_custom_module_node = node.args[0]
assert isinstance(call_custom_module_node, Node), \
f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
@ -479,7 +479,7 @@ def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig)
# run the weight observer
weight_observer_module()
def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
""" If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
we'll recursively remove the dequantize Node
"""
@ -502,7 +502,7 @@ def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph):
def _get_module_path_and_prefix(
obs_node: Node,
node_name_to_scope: Dict[str, Tuple[str, type]],
node_name_to_qconfig: Dict[str, QConfigAny]):
node_name_to_qconfig: Dict[str, QConfigAny]) -> Tuple[str, str]:
""" Given and observer node, get the `Scope` or the fully qualified name for
the submodule containing the observed node, also return a prefix of "_input"
when the observed node is an input of a F.linear op, and not the output of another
@ -549,7 +549,7 @@ def _get_module_path_and_prefix(
def _insert_dequantize_node(
node: Node,
graph: Graph):
graph: Graph) -> None:
""" Inserts dequantize node for `node` in `graph`
"""
with graph.inserting_after(node):
@ -578,7 +578,7 @@ def convert_standalone_module(
modules: Dict[str, torch.nn.Module],
model: torch.fx.GraphModule,
is_reference: bool,
backend_config: Optional[BackendConfig]):
backend_config: Optional[BackendConfig]) -> None:
""" Converts a observed standalone module to a quantized standalone module by calling
the fx convert api, currently using the same `is_reference` flag as parent, but we may
changing this behavior in the future (e.g. separating quantization and lowering for
@ -641,7 +641,7 @@ def convert_weighted_module(
observed_node_names: Set[str],
node_name_to_qconfig: Dict[str, QConfigAny],
backend_config: BackendConfig,
is_decomposed: bool = False):
is_decomposed: bool = False) -> None:
""" Convert a weighted module to reference quantized module in the model
If the QConfig of a QAT module is not set, the module will still be converted to
a float module.
@ -746,7 +746,7 @@ def convert_weighted_module(
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, ref_qmodule)
def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph):
def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph) -> None:
"""
Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows:
@ -768,7 +768,7 @@ def convert_custom_module(
graph: Graph,
modules: Dict[str, torch.nn.Module],
custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
statically_quantized_custom_module_nodes: Set[Node]):
statically_quantized_custom_module_nodes: Set[Node]) -> None:
""" Converts an observed custom module to a quantized custom module based on
`custom_module_class_mapping`
For static quantization, we'll also remove the previous `dequantize` node and
@ -853,7 +853,7 @@ def convert(
_remove_qconfig_flag: bool = True,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
is_decomposed: bool = False) -> torch.nn.Module:
is_decomposed: bool = False) -> GraphModule:
"""
We will convert an observed model (a module with observer calls) to a reference
quantized model, the rule is simple:

View File

@ -1333,7 +1333,7 @@ def insert_observers_for_model(
# if not, we'll reset the target_dtye_info to use the default (float Tensor)
# reset the counters and set of processed_nodes
processed_nodes = set()
processed_nodes: Set[Node] = set()
for node_name, match_res_with_qconfig in node_name_to_match_result_with_qconfig.items():
last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(

View File

@ -28,7 +28,9 @@ from .fx.utils import get_skipped_module_name_and_classes
from .qconfig_mapping import QConfigMapping
def attach_preserved_attrs_to_model(
model: Union[GraphModule, torch.nn.Module], preserved_attrs: Dict[str, Any]):
model: Union[GraphModule, torch.nn.Module],
preserved_attrs: Dict[str, Any],
) -> None:
""" Store preserved attributes to the model.meta so that it can be preserved during deepcopy
"""
model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs) # type: ignore[operator, index, assignment]
@ -47,7 +49,7 @@ def _check_is_graph_module(model: torch.nn.Module) -> None:
+ "sure to follow the tutorials."
)
def _attach_meta_to_node_if_not_exist(model: GraphModule):
def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None:
""" Attach meta field to all nodes of the graph if it does not exist,
meta field is a field stores some meta information about the node, such
as dtype and shape information for output of the node, this only exists
@ -503,7 +505,7 @@ def _convert_fx(
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
is_decomposed: bool = False,
) -> torch.nn.Module:
) -> GraphModule:
""" `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
"""
if convert_custom_config is None:
@ -540,7 +542,7 @@ def convert_fx(
_remove_qconfig: bool = True,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
) -> GraphModule:
r""" Convert a calibrated or trained model to a quantized model
Args:
@ -607,7 +609,7 @@ def convert_to_reference_fx(
_remove_qconfig: bool = True,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
) -> GraphModule:
r""" Convert a calibrated or trained model to a reference quantized model,
see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
reference quantized model is a standard representation of a quantized model provided
@ -656,7 +658,7 @@ def _convert_to_reference_decomposed_fx(
_remove_qconfig: bool = True,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
) -> GraphModule:
r""" Convert a calibrated or trained model to a reference quantized model, with
decomposed representation for quantized Tensor
see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details,
@ -708,7 +710,7 @@ def _convert_standalone_module_fx(
graph_module: GraphModule,
is_reference: bool = False,
convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
) -> GraphModule:
r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
and convert it to a quantized model

View File

@ -124,7 +124,7 @@ class _lazy_property_and_property(lazy_property, property):
property.__init__(self, wrapped)
def tril_matrix_to_vec(mat, diag=0):
def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor:
r"""
Convert a `D x D` matrix or a batch of matrices into a (batched) vector
which comprises of lower triangular elements from the matrix in row order.
@ -138,7 +138,7 @@ def tril_matrix_to_vec(mat, diag=0):
return vec
def vec_to_tril_matrix(vec, diag=0):
def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor:
r"""
Convert a vector or a batch of vectors into a batched `D x D`
lower triangular matrix containing elements from the vector in row order.
@ -149,7 +149,7 @@ def vec_to_tril_matrix(vec, diag=0):
if not torch._C._get_tracing_state() and (round(n) - n > eps):
raise ValueError(f'The size of last dimension is {vec.shape[-1]} which cannot be expressed as ' +
'the lower triangular part of a square D x D matrix.')
n = torch.round(n).long() if isinstance(n, torch.Tensor) else round(n)
n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
arange = torch.arange(n, device=vec.device)
tril_mask = arange < arange.view(-1, 1) + (diag + 1)

View File

@ -194,12 +194,12 @@ class CapabilityBasedPartitioner:
for id, partition in partitions_by_id.items():
compute_node_count = 0
for node in partition.nodes:
if node.op == "call_function" and \
_get_qualified_name(node.target) not in non_compute_ops: # type: ignore[arg-type]
compute_node_count += 1
if node.op == "call_function" and \
_get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
compute_node_count += 1
if node.op == "call_function":
assert callable(node.target)
if _get_qualified_name(node.target) not in non_compute_ops:
compute_node_count += 1
if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
compute_node_count += 1
if compute_node_count <= 1:
partitions_to_remove.append(id)
for id in partitions_to_remove:

View File

@ -1462,14 +1462,14 @@ def embedding_bag(
def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type):
unique_indices = g.op("Unique", indices)
partial_weight = g.op("Gather", weight, unique_indices)
norm_type = int(norm_type)
if norm_type == 1:
norm_i = int(norm_type)
if norm_i == 1:
norm_type = "ReduceL1"
elif norm_type == 2:
elif norm_i == 2:
norm_type = "ReduceL2"
else:
raise errors.SymbolicValueError(
f"Unsupported: ONNX export of embedding_renorm with norm: {norm_type}. "
f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. "
"Only 1. and 2. are supported.",
weight,
)

View File

@ -2025,8 +2025,9 @@ def _write_ninja_file_to_build_library(path,
cuda_flags += extra_cuda_cflags
if not any(flag.startswith('-std=') for flag in cuda_flags):
cuda_flags.append('-std=c++17')
if os.getenv("CC") is not None:
cuda_flags = ['-ccbin', os.getenv("CC")] + cuda_flags
cc_env = os.getenv("CC")
if cc_env is not None:
cuda_flags = ['-ccbin', cc_env] + cuda_flags
else:
cuda_flags = None