remove unused type: ignore directives (#60006)

Summary:
During development it is common practice to put `type: ignore` comments on lines that are correct, but `mypy` doesn't recognize this. This often stems from the fact, that the used `mypy` version wasn't able to handle the used pattern.

With every new release `mypy` gets better at handling complex code. In addition to fix all the previously accepted but now failing patterns, we should also revisit all `type: ignore` comments to see if they are still needed or not. Fortunately, we don't need to do it manually: by adding `warn_unused_ignores = True` to the configuration, `mypy` will error out in case it encounters an `type: ignore` that is no longer needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60006

Reviewed By: jbschlosser, malfet

Differential Revision: D29133237

Pulled By: albanD

fbshipit-source-id: 41e82edc5cd5affa7ccedad044b59b94dad4425a
This commit is contained in:
Philip Meier
2021-06-18 07:22:22 -07:00
committed by Facebook GitHub Bot
parent 7c29ca7f2b
commit d5988c5eca
37 changed files with 108 additions and 93 deletions

View File

@ -26,7 +26,7 @@ def fn(base: str) -> str:
with open(Path(__file__).parent.parent.parent / fn('.'), "r") as f:
contents = f.read()
yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
yaml = ruamel.yaml.YAML()
yaml.preserve_quotes = True
yaml.width = 1000
yaml.boolean_representation = ['False', 'True']

View File

@ -39,7 +39,7 @@ if args.aten_root:
sys.path.insert(0, os.path.join(args.aten_root, '..'))
from tools.codegen.code_template import CodeTemplate as CT
else:
from tools.codegen.code_template import CodeTemplate as CT # type: ignore[import,no-redef]
from tools.codegen.code_template import CodeTemplate as CT
OP_TEMPLATE = CT.from_file(
os.path.join(args.template_dir, 'aten_op_template.h'))

View File

@ -8,7 +8,7 @@ import os
import tempfile
import shutil
from caffe2.distributed.python import StoreHandlerTimeoutError # type: ignore[import]
from caffe2.distributed.python import StoreHandlerTimeoutError
from caffe2.distributed.store_ops_test_util import StoreOpsTests
from caffe2.python import core, workspace, dyndep
from caffe2.python.test_util import TestCase

View File

@ -6,7 +6,7 @@
import os
import uuid
from caffe2.distributed.python import StoreHandlerTimeoutError # type: ignore[import]
from caffe2.distributed.python import StoreHandlerTimeoutError
from caffe2.distributed.store_ops_test_util import StoreOpsTests
from caffe2.python import core, workspace, dyndep
from caffe2.python.test_util import TestCase

View File

@ -19,6 +19,7 @@ disallow_any_unimported = True
# Across versions of mypy, the flags toggled by --strict vary. To ensure
# we have reproducible type check, we instead manually specify the flags
warn_unused_configs = True
warn_unused_ignores = True
disallow_any_generics = True
disallow_subclassing_any = True
disallow_untyped_calls = True

View File

@ -6,6 +6,7 @@ plugins = mypy_plugins/check_mypy_version.py
cache_dir = .mypy_cache/normal
warn_unused_configs = True
warn_unused_ignores = True
warn_redundant_casts = True
show_error_codes = True
show_column_numbers = True
@ -95,6 +96,19 @@ ignore_errors = True
[mypy-torch.overrides]
ignore_errors = True
#
# Files with 'type: ignore' comments that are needed if checked with mypy-strict.ini
#
[mypy-tools.render_junit]
warn_unused_ignores = False
[mypy-tools.generate_torch_version]
warn_unused_ignores = False
[mypy-tools.stats_utils.s3_stat_parser]
warn_unused_ignores = False
#
# Adding type annotations to caffe2 is probably not worth the effort
# only work on this if you have a specific reason for it, otherwise

View File

@ -30,7 +30,7 @@ class TestFuture(TestCase):
f = Future()
f.set_exception(value_error)
with self.assertRaisesRegex(ValueError, "Intentional"):
f.value() # type: ignore[attr-defined]
f.value()
def cb(fut):
fut.value()

View File

@ -743,7 +743,7 @@ class TestAssert(TestCase):
# data can be passed without errors
x = torch.randn(4, 4).fill_(1.0)
ms(x)
with self.assertRaisesRegex(torch.jit.Error, "foo"): # type: ignore[type-var]
with self.assertRaisesRegex(torch.jit.Error, "foo"):
ms(torch.tensor([False], dtype=torch.bool))

View File

@ -126,7 +126,10 @@ binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'tru
'iadd', 'iand', 'idiv', 'ilshift', 'imul',
'ior', 'irshift', 'isub', 'ixor', 'ifloordiv', 'imod', # inplace ops
)
comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
symmetric_comparison_ops = ('eq', 'ne')
asymmetric_comparison_ops = ('ge', 'gt', 'lt', 'le')
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
unary_ops = ('neg', 'abs', 'invert')
to_py_type_ops = ('bool', 'float', 'complex', 'long', 'index', 'int', 'nonzero')
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
@ -145,8 +148,11 @@ def sig_for_ops(opname: str) -> List[str]:
if name in binary_ops:
return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)]
elif name in comparison_ops:
# unsafe override https://github.com/python/mypy/issues/5704
return ['def {}(self, other: Any) -> Tensor: ... # type: ignore[override]'.format(opname)]
sig = 'def {}(self, other: Any) -> Tensor: ...'.format(opname)
if name in symmetric_comparison_ops:
# unsafe override https://github.com/python/mypy/issues/5704
sig += ' # type: ignore[override]'
return [sig]
elif name in unary_ops:
return ['def {}(self) -> Tensor: ...'.format(opname)]
elif name in to_py_type_ops:

View File

@ -12,7 +12,7 @@ except ImportError:
)
try:
import rich # type: ignore[import]
import rich
except ImportError:
print("rich not found, for color output use 'pip install rich'")

View File

@ -309,7 +309,7 @@ class ShardedTensor(object):
def _parse_and_validate_remote_device(self, device):
on, local_device = _parse_remote_device(device) # type: ignore[arg-type]
on, local_device = _parse_remote_device(device)
# Validate rank.
if isinstance(on, int) and (on < 0 or on >= dist.get_world_size(self._process_group)):

View File

@ -1591,7 +1591,7 @@ def all_gather_object(object_list, obj, group=None):
all_gather(output_tensors, input_tensor, group=group)
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8) # type:ignore[call-overload]
tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
@ -1695,7 +1695,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
if my_rank != dst:
return
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8) # type: ignore[call-overload]
tensor = tensor.type(torch.uint8)
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
@ -1790,7 +1790,7 @@ def broadcast_object_list(object_list, src=0, group=None):
if my_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset : offset + obj_size]
obj_view = obj_view.type(torch.uint8) # type: ignore[call-overload]
obj_view = obj_view.type(torch.uint8)
if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu()
offset += obj_size

View File

@ -143,7 +143,7 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
# see the explanation in the except clause below.
for is_server in [is_host, False]:
try:
store = TCPStore( # type: ignore[call-arg]
store = TCPStore(
host, port, is_master=is_server, timeout=timedelta(seconds=read_timeout)
)

View File

@ -44,7 +44,7 @@ class ElasticDistributedSampler(DistributedSampler):
self.start_index = start_index
self.num_samples = int(
math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas) # type: ignore[arg-type]
math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas)
)
self.total_size = self.num_samples * self.num_replicas
@ -53,7 +53,7 @@ class ElasticDistributedSampler(DistributedSampler):
g = torch.Generator()
g.manual_seed(self.epoch)
indices = (
torch.randperm(len(self.dataset) - self.start_index, generator=g) # type: ignore[arg-type]
torch.randperm(len(self.dataset) - self.start_index, generator=g)
.add(self.start_index)
.tolist()
)

View File

@ -12,8 +12,8 @@ from typing import Any, Callable, Dict, List, Optional, Union, cast, Tuple
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic import events, metrics
from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState # type: ignore[import]
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent # type: ignore[import]
from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, record
from torch.distributed.elastic.rendezvous import RendezvousParameters

View File

@ -395,10 +395,10 @@ class _RemoteModule(nn.Module):
):
_raise_not_supported(self.named_modules.__name__)
def train(self: T, mode: bool = True) -> T: # type: ignore[return]
def train(self: T, mode: bool = True) -> T:
return self.module_rref.rpc_sync().train() # type: ignore[operator, union-attr]
def eval(self: T) -> T: # type: ignore[return]
def eval(self: T) -> T:
return self.module_rref.rpc_sync().eval() # type: ignore[operator, union-attr]
def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return]
@ -413,7 +413,7 @@ class _RemoteModule(nn.Module):
def extra_repr(self) -> str: # type: ignore[return]
_raise_not_supported(self.extra_repr.__name__)
def _prepare_init(self, remote_device: str) -> bool: # type: ignore[return]
def _prepare_init(self, remote_device: str) -> bool:
"""
Prepares the initializaiton and returns whether to enable automatically moving CPU tensors to CUDA devices.
"""
@ -639,7 +639,7 @@ class RemoteModule(_RemoteModule):
args: Tuple = None,
kwargs: Dict[str, Any] = None,
):
super().__init__(remote_device, module_cls, args, kwargs) # type: ignore[arg-type]
super().__init__(remote_device, module_cls, args, kwargs)
def _remote_module_receiver(
@ -651,7 +651,7 @@ def _remote_module_receiver(
serialized_remote_module = _SerializedRemoteModule._make(
remote_module_pickled_attrs
)
m = object.__new__(RemoteModule) # type: ignore[attr-defined]
m = object.__new__(RemoteModule)
m.__dict__.update(serialized_remote_module._asdict())
# Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
@ -675,10 +675,10 @@ def _remote_module_reducer(remote_module):
# Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method.
if k == "module_rref":
pickled_attrs[k] = v._serialize()
elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES: # type: ignore[attr-defined]
elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES:
pickled_attrs[k] = v
# Check if unpickled attributes are all in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING: # type: ignore[attr-defined]
elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING:
print(
"The new attribute ``{}`` of RemoteModule is ignored during RPC pickling. "
"To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "

View File

@ -60,8 +60,7 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
result = urlparse(url)
if rank != -1 or world_size != -1:
query_dict: Dict[str, Union[int, str]] = dict(
# mypy doesn't allow dict() to accept List of values (#257)
pair.split("=") for pair in filter(None, result.query.split("&")) # type: ignore[arg-type, misc]
pair.split("=") for pair in filter(None, result.query.split("&"))
)
assert (
"rank" not in query_dict and "world_size" not in query_dict

View File

@ -101,7 +101,7 @@ class lazy_property(object):
"""
def __init__(self, wrapped):
self.wrapped = wrapped
update_wrapper(self, wrapped) # type: ignore[arg-type]
update_wrapper(self, wrapped)
def __get__(self, instance, obj_type=None):
if instance is None:

View File

@ -74,7 +74,7 @@ class NormalizeArgs(Transformer):
args, # type: ignore[arg-type]
kwargs,
arg_types, # type: ignore[arg-type]
kwarg_types, # type: ignore[arg-type]
kwarg_types,
self.normalize_to_only_use_kwargs,
)
if new_args_and_kwargs:
@ -93,7 +93,7 @@ class NormalizeArgs(Transformer):
self.module,
target,
args, # type: ignore[arg-type]
kwargs, # type: ignore[arg-type]
kwargs,
self.normalize_to_only_use_kwargs,
)
if new_args_and_kwargs:

View File

@ -256,7 +256,7 @@ class _MinimizerBase:
if node in selected_nodes:
node.tag = "minimize"
elif any(
n.tag in {"minimize", "main_1"} # type: ignore[attr-defined]
n.tag in {"minimize", "main_1"}
for n in node.all_input_nodes
if n.op in CALLABLE_NODE_OPS
):

View File

@ -10,8 +10,8 @@ class Partition:
self.outputs: Dict[str, None] = {}
self.partitions_dependent_on: Dict[str, None] = {}
self.partition_dependents: Dict[str, None] = {}
self.graph : torch.fx.graph.Graph = torch.fx.graph.Graph() # type: ignore[attr-defined, name-defined]
self.environment : Dict[torch.fx.node.Node, torch.fx.node.Node] = {} # type: ignore[name-defined]
self.graph : torch.fx.graph.Graph = torch.fx.graph.Graph()
self.environment : Dict[torch.fx.node.Node, torch.fx.node.Node] = {}
self.targets : Dict[str, Any] = {}
def __repr__(self) -> str:
@ -26,12 +26,12 @@ class Partition:
def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int], # type: ignore[name-defined]
split_callback: Callable[[torch.fx.node.Node], int],
):
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {} # type: ignore[name-defined]
orig_nodes: Dict[str, torch.fx.node.Node] = {}
def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optional[torch.fx.node.Node]): # type: ignore[name-defined] # noqa: B950
def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
if def_partition_name != use_partition_name:
@ -56,7 +56,7 @@ def split_module(
if node.op in ["placeholder", "get_attr"]:
continue
if node.op == 'output':
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None)) # type: ignore[attr-defined]
torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
continue
partition_name = str(split_callback(node))
@ -68,8 +68,8 @@ def split_module(
partition.node_names.append(node.name)
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) # type: ignore[attr-defined]
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # type: ignore[attr-defined] # noqa: B950
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions : List[str] = []
@ -104,8 +104,8 @@ def split_module(
# swap out old graph nodes in kw/args with references to new nodes in this submodule
environment = partition.environment
gathered_args = torch.fx.graph.map_arg(node.args, lambda n : environment[n]) # type: ignore[attr-defined]
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n : environment[n]) # type: ignore[attr-defined]
gathered_args = torch.fx.graph.map_arg(node.args, lambda n : environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n : environment[n])
if node.op not in ['call_module', 'get_attr']:
target = node.target
@ -128,9 +128,9 @@ def split_module(
partition.environment[node] = new_node
# Set up values to construct base module
base_mod_env : Dict[str, torch.fx.node.Node] = {} # type: ignore[name-defined]
base_mod_graph : torch.fx.graph.Graph = torch.fx.graph.Graph() # type: ignore[attr-defined, name-defined]
base_mod_attrs : Dict[str, torch.fx.graph_module.GraphModule] = {} # type: ignore[name-defined]
base_mod_env : Dict[str, torch.fx.node.Node] = {}
base_mod_graph : torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs : Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
base_mod_env[node.name] = base_mod_graph.placeholder(node.name)
@ -159,14 +159,14 @@ def split_module(
# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, partition.graph) # type: ignore[attr-defined] # noqa: B950
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, partition.graph) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
if len(partition.outputs) > 1:
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val) # type: ignore[attr-defined]
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
@ -174,6 +174,6 @@ def split_module(
for node in m.graph.nodes:
if node.op == 'output':
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n : base_mod_env[n.name])) # type: ignore[attr-defined] # noqa: B950
base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n : base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) # type: ignore[attr-defined]
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)

View File

@ -53,7 +53,7 @@ if _IS_MONKEYTYPE_INSTALLED:
# and create a dictionary of all the types
# for arguments.
records = self.trace_records[qualified_name]
all_args = defaultdict(set) # type: ignore[var-annotated]
all_args = defaultdict(set)
for record in records:
for arg, arg_type in record.arg_types.items():
all_args[arg].add(arg_type)
@ -123,4 +123,4 @@ else:
def __init__(self):
pass
monkeytype_trace = None # type: ignore[assignment] # noqa: F811
monkeytype_trace = None # noqa: F811

View File

@ -45,9 +45,9 @@ def _load_for_lite_interpreter(f, map_location=None):
map_location = validate_map_location(map_location)
if isinstance(f, str) or isinstance(f, pathlib.Path):
cpp_module = torch._C._load_for_lite_interpreter(f, map_location) # type: ignore[attr-defined]
cpp_module = torch._C._load_for_lite_interpreter(f, map_location)
else:
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(f.read(), map_location) # type: ignore[attr-defined]
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(f.read(), map_location)
return LiteScriptModule(cpp_module)
@ -102,9 +102,9 @@ def _get_model_bytecode_version(f_input) -> int:
raise ValueError(f"The provided filename {f_input} is a directory")
if (isinstance(f_input, str) or isinstance(f_input, pathlib.Path)):
return torch._C._get_model_bytecode_version(str(f_input)) # type: ignore[attr-defined]
return torch._C._get_model_bytecode_version(str(f_input))
else:
return torch._C._get_model_bytecode_version_from_buffer(f_input.read()) # type: ignore[attr-defined]
return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
def _backport_for_mobile(f_input, f_output, to_version):
r"""
@ -124,9 +124,9 @@ def _backport_for_mobile(f_input, f_output, to_version):
if ((isinstance(f_input, str) or isinstance(f_input, pathlib.Path)) and (
isinstance(f_output, str) or isinstance(f_output, pathlib.Path))):
return torch._C._backport_for_mobile(str(f_input), str(f_output), to_version) # type: ignore[attr-defined]
return torch._C._backport_for_mobile(str(f_input), str(f_output), to_version)
else:
return torch._C._backport_for_mobile_from_buffer(f_input.read(), str(f_output), to_version) # type: ignore[attr-defined]
return torch._C._backport_for_mobile_from_buffer(f_input.read(), str(f_output), to_version)
def _backport_for_mobile_to_buffer(f_input, to_version):
r"""
@ -142,9 +142,9 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
raise ValueError(f"The provided filename {f_input} is a directory")
if (isinstance(f_input, str) or isinstance(f_input, pathlib.Path)):
return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version) # type: ignore[attr-defined]
return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version)
else:
return torch._C._backport_for_mobile_from_buffer_to_buffer(f_input.read(), to_version) # type: ignore[attr-defined]
return torch._C._backport_for_mobile_from_buffer_to_buffer(f_input.read(), to_version)
def _get_model_ops_and_info(f_input):
r"""
@ -182,6 +182,6 @@ def _get_model_ops_and_info(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if (isinstance(f_input, str) or isinstance(f_input, pathlib.Path)):
return torch._C._get_model_ops_and_info(str(f_input)) # type: ignore[attr-defined]
return torch._C._get_model_ops_and_info(str(f_input))
else:
return torch._C._get_model_ops_and_info(f_input.read()) # type: ignore[attr-defined]
return torch._C._get_model_ops_and_info(f_input.read())

View File

@ -201,7 +201,7 @@ class _ConvNd(nn.Module):
'Weight observer must have a dtype of qint8'
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
# the __init__ call used is the one from derived classes and not the one from _ConvNd
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg]
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
mod.stride, mod.padding, mod.dilation, mod.groups,
mod.bias is not None, mod.padding_mode)
qconv.set_weight_bias(qweight, mod.bias)

View File

@ -90,7 +90,7 @@ class _SpectralNorm(Module):
# This power iteration produces approximations of `u` and `v`.
self._u = F.normalize(torch.mv(weight_mat, self._v), # type: ignore[has-type]
dim=0, eps=self.eps, out=self._u) # type: ignore[has-type]
self._v = F.normalize(torch.mv(weight_mat.t(), self._u), # type: ignore[has-type]
self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
dim=0, eps=self.eps, out=self._v) # type: ignore[has-type]
# See above on why we need to clone
self._u = self._u.clone(memory_format=torch.contiguous_format)

View File

@ -366,7 +366,7 @@ class PackageImporter(Importer):
return
# Set the module as an attribute on its parent.
parent_module = self.modules[parent]
if parent_module.__loader__ is self: # type: ignore[union-attr]
if parent_module.__loader__ is self:
setattr(parent_module, name.rpartition(".")[2], module)
# note: copied from cpython's import code, with call to create module replaced with _make_module

View File

@ -394,7 +394,7 @@ def convert(model: GraphModule, is_reference: bool = False,
# for non-standalone module, since _standalone_module_output_quantized_idxs
# is only available in observed standalone module
if is_observed_standalone_module_node:
out_quant_idxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist() # type: ignore[operator] # noqa: B950
out_quant_idxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist() # noqa: B950
assert len(out_quant_idxs) <= 1, "Currently standalone only support one output"
quantized = 0 in out_quant_idxs

View File

@ -425,7 +425,7 @@ def maybe_insert_input_observers_for_node(
# assign the new args and kwargs to the node, inplace
node.args = tuple(new_args)
node.kwargs = new_kwargs # type: ignore[assignment]
node.kwargs = new_kwargs
def maybe_insert_input_equalization_observers_for_node(
node: Node,
@ -946,7 +946,7 @@ def run_prepare_fx_on_standalone_modules(
get_standalone_module_configs(
root_node, modules, prepare_custom_config_dict, qconfig)
standalone_module = modules[root_node.target] # type: ignore[index]
standalone_module = modules[root_node.target]
prepare = \
torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined]
observed_standalone_module = \
@ -959,7 +959,7 @@ def run_prepare_fx_on_standalone_modules(
parent_name, name = _parent_name(root_node.target)
setattr(modules[parent_name], name,
observed_standalone_module)
modules[root_node.target] = observed_standalone_module # type: ignore[index]
modules[root_node.target] = observed_standalone_module
def save_state(
observed: GraphModule,

View File

@ -508,7 +508,7 @@ class MultiProcessTestCase(TestCase):
if sys.platform != 'win32' and sys.platform != 'darwin':
# Register signal handler to dump stack traces on FATALs.
# Windows and MacOS do not support the signal handlers.
torch._C._set_print_stack_traces_on_fatal_signal(True) # type: ignore[attr-defined]
torch._C._set_print_stack_traces_on_fatal_signal(True)
# self.id() == e.g. '__main__.TestDistributed.test_get_rank'
# We're retrieving a corresponding test and executing it.

View File

@ -3292,24 +3292,24 @@ def sample_inputs_fmod_remainder(op_info, device, dtype, requires_grad, *, autod
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
if autodiffed:
samples = ( # type: ignore[assignment]
samples = (
((S, S, S), 1.5, False),
((), 1.5, False),
)
else:
cases = ( # type: ignore[assignment]
cases = (
((S, S, S), (), False),
((S, S, S), (S, S, S), False),
((S, S, S), (S,), False),
)
# Sample inputs with scalars as torch tensors
cases_with_tensor_scalar = ( # type: ignore[assignment]
cases_with_tensor_scalar = (
((), torch.tensor(1, dtype=dtype, device=device, requires_grad=False), False),
)
# Sample inputs with broadcasting
cases_with_broadcasting = ( # type: ignore[assignment]
cases_with_broadcasting = (
((S,), (S, S, S), True),
((S, 1, S), (S, S, S), True),
((), (S, S, S), True),
@ -3978,7 +3978,7 @@ def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=Fals
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
if list_args:
cases = ( # type: ignore[assignment]
cases = (
((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],)),
((S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2),),
((S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], -2),)

View File

@ -1440,7 +1440,7 @@ class TestCase(expecttest.TestCase):
super().assertEqual(x, y, msg=msg)
def assertNotEqual(self, x, y, msg: Optional[str] = None, *, # type: ignore[override]
atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None: # type: ignore[override]
atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None:
with self.assertRaises(AssertionError, msg=msg):
self.assertEqual(x, y, msg, atol=atol, rtol=rtol, **kwargs)

View File

@ -6,7 +6,7 @@ from typing import Tuple
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc import _rref_context_get_debug_info # type: ignore[attr-defined]
from torch.distributed.rpc import _rref_context_get_debug_info
from torch.testing._internal.common_utils import FILE_SCHEMA

View File

@ -265,11 +265,9 @@ class DataLoader(Generic[T_co]):
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
# Cannot statically verify that dataset is Sized
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
sampler = RandomSampler(dataset, generator=generator)
else:
sampler = SequentialSampler(dataset) # type: ignore[arg-type]
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler

View File

@ -295,7 +295,6 @@ class ChainDataset(IterableDataset):
total = 0
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
# Cannot verify that all self.datasets are Sized
total += len(d)
return total
@ -338,7 +337,7 @@ def random_split(dataset: Dataset[T], lengths: Sequence[int],
generator (Generator): Generator used for the random permutation.
"""
# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore[arg-type]
if sum(lengths) != len(dataset):
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
indices = randperm(sum(lengths), generator=generator).tolist()

View File

@ -78,17 +78,15 @@ class DistributedSampler(Sampler[T_co]):
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
# `type:ignore` is required because Dataset cannot provide a default __len__
# see NOTE in pytorch/torch/utils/data/sampler.py
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas) / self.num_replicas
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
@ -98,9 +96,9 @@ class DistributedSampler(Sampler[T_co]):
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
indices = list(range(len(self.dataset)))
if not self.drop_last:
# add extra samples to make it evenly divisible

View File

@ -334,10 +334,10 @@ def get_inline_skeleton():
import importlib.resources
skeleton = importlib.resources.read_text(__package__, "skeleton.html") # type: ignore[attr-defined]
js_code = importlib.resources.read_text(__package__, "code.js") # type: ignore[attr-defined]
skeleton = importlib.resources.read_text(__package__, "skeleton.html")
js_code = importlib.resources.read_text(__package__, "code.js")
for js_module in ["preact", "htm"]:
js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs") # type: ignore[attr-defined]
js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs")
js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)
skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code)

View File

@ -837,7 +837,7 @@ class SummaryWriter(object):
metadata, label_img, fs, subdir, global_step, tag)
self._projector_config.embeddings.extend([embedding_info])
from google.protobuf import text_format # type: ignore[attr-defined]
from google.protobuf import text_format
config_pbtxt = text_format.MessageToString(self._projector_config)
write_pbtxt(self._get_file_writer().get_logdir(), config_pbtxt)