Delete a bunch of type-ignores (#113990)

* Replaced `# type: ignore[import]` comments with mypy config file entries
  (see the short sketch after this list)
* Removed a bunch of ignores around previously-fixed attr-defined /
  call-arg issues
* Fixed some invalid / undefined types; added a few more type-ignores to
  squelch the downstream errors this exposed
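
As a short sketch of that first change (the module name `somepkg` is a
placeholder, not one of the modules touched by this PR): a per-site
suppression such as

    import somepkg  # type: ignore[import]

becomes a plain `import somepkg` once the mypy config carries

    [mypy-somepkg.*]
    ignore_missing_imports = True

which silences the missing-stubs error for every import of the package
instead of one call site at a time.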

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113990
Approved by: https://github.com/eellison, https://github.com/Skylion007
ghstack dependencies: #113979
Authored by Jez Ng on 2023-11-17 15:33:34 -08:00; committed by PyTorch MergeBot
parent 47220bc72a
commit 4667e20b3f
24 changed files with 72 additions and 59 deletions

View File

@@ -24,6 +24,12 @@ files =
# understood by mypy.
python_version = 3.11
[mypy-colorama.*]
ignore_missing_imports = True
[mypy-cutlass_library.*]
ignore_missing_imports = True
[mypy-deeplearning.*]
ignore_missing_imports = True
@@ -33,11 +39,17 @@ ignore_missing_imports = True
[mypy-einops.*]
ignore_missing_imports = True
[mypy-libfb.*]
ignore_missing_imports = True
# sympy is too dynamic, hard to type properly
[mypy-sympy.*]
ignore_missing_imports = True
follow_imports = skip
[mypy-torch.*.fb.*]
ignore_missing_imports = True
# FIXME: importing this creates lots of type errors
[mypy-torch._dynamo.variables.*]
follow_imports = skip
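
A brief illustrative note on the `follow_imports = skip` entries above (the
snippet below is not part of the diff): with that setting mypy does not
analyze the package's own sources, so anything imported from it is treated
as `Any` and stops producing downstream errors, roughly:

    import sympy

    x = sympy.Symbol("x")  # typed as Any under follow_imports = skip
    expr = x**2 + 1        # mypy reports no errors for sympy expressions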

View File

@@ -34,7 +34,7 @@ def openxla_eval_boxed(model, fake_tensor_inputs):
def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
try:
import torch_xla.core.dynamo_bridge as bridge # type: ignore[import]
import torch_xla.core.dynamo_bridge as bridge
except ImportError as e:
raise ImportError(
"Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"

View File

@@ -34,7 +34,7 @@ inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()
if use_buck:
import libfb.py.build_info # type: ignore[import]
import libfb.py.build_info
extra_deps = []

View File

@@ -157,7 +157,7 @@ DONT_WRAP_FILES = {
def _debug_get_cache_entry_list(
code: Union[types.CodeType, Callable[..., Any]]
) -> List[CacheEntry]: # type: ignore[valid-type]
) -> List[CacheEntry]:
"""
Given a code object or a callable object, retrieve the cache entries
stored in this code.

View File

@@ -53,7 +53,7 @@ class DynamoCallbackFn(Protocol):
def __call__(
self,
frame: DynamoFrameType,
cache_entry: Optional[CacheEntry], # type: ignore[valid-type]
cache_entry: Optional[CacheEntry],
frame_state: FrameState,
) -> Optional[GuardedCode]:
...

View File

@@ -65,7 +65,7 @@ if config.is_fbcode():
from triton.fb import build_paths
from triton.fb.build import _run_build_command
from torch._inductor.fb.utils import ( # type: ignore[import]
from torch._inductor.fb.utils import (
log_global_cache_errors,
log_global_cache_stats,
log_global_cache_vals,

View File

@@ -78,7 +78,7 @@ class CUDACPPScheduling(BaseScheduling):
# We can fuse a Pointwise op that depends on the last fused epilogue node
# if any. If there is no epilogue node yet, it needs to depend on the template
# node
node_name = additional_node.get_computed_buffer_name() # type: ignore[attr-defined]
node_name = additional_node.get_computed_buffer_name()
if node_name is None:
return False

View File

@@ -84,7 +84,7 @@ class CutlassEVTEpilogueTypeFormatter:
template_output_node_name, evt_type_name
)
with virtualized.V.set_ops_handler(formatter), patch.object( # type: ignore[call-arg]
with virtualized.V.set_ops_handler(formatter), patch.object(
FlexibleLayout, "allow_indexing", True
):
for node in epilogue_nodes:
@@ -253,7 +253,7 @@ class CutlassEVTEpilogueArgumentFormatter:
template_output_node_name,
)
with virtualized.V.set_ops_handler(formatter), patch.object( # type: ignore[call-arg]
with virtualized.V.set_ops_handler(formatter), patch.object(
FlexibleLayout, "allow_indexing", True
):
for node in epilogue_nodes:

View File

@@ -3,8 +3,8 @@ from ..cutlass_utils import try_import_cutlass
if try_import_cutlass():
import enum
from cutlass_library.library import * # type: ignore[import] # noqa: F401, F403
from cutlass_library.gemm_operation import * # type: ignore[import] # noqa: F401, F403
from cutlass_library.library import * # noqa: F401, F403
from cutlass_library.gemm_operation import * # noqa: F401, F403
# copied / modified from original at
# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658

View File

@@ -72,9 +72,9 @@ def try_import_cutlass() -> bool:
os.symlink(cutlass_py_full_path, dst_link)
sys.path.append(tmp_cutlass_py_full_path)
try:
import cutlass_library.generator # type: ignore[import] # noqa: F401
import cutlass_library.library # type: ignore[import] # noqa: F401
import cutlass_library.manifest # type: ignore[import] # noqa: F401
import cutlass_library.generator # noqa: F401
import cutlass_library.library # noqa: F401
import cutlass_library.manifest # noqa: F401
return True
@@ -139,8 +139,8 @@ def _gen_ops_cached(arch, version) -> List[Any]:
# Import cutlass python scripts.
assert try_import_cutlass()
import cutlass_library.generator as cutlass_generator # type: ignore[import]
import cutlass_library.manifest as cutlass_manifest # type: ignore[import]
import cutlass_library.generator as cutlass_generator
import cutlass_library.manifest as cutlass_manifest
if arch is None or version is None:
log.error(
@@ -184,7 +184,7 @@ def dtype_match(
) -> bool:
# Import cutlass python scripts.
assert try_import_cutlass()
import cutlass_library # type: ignore[import]
import cutlass_library
if torch_dtype == torch.float:
return (

View File

@@ -266,7 +266,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
@staticmethod
def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib
if torch_layout.stride[-1] == 1:
return cutlass_lib.LayoutType.RowMajor
@@ -280,7 +280,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined]
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib
if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
return cutlass_lib.LayoutType.ColumnMajor
@@ -303,7 +303,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
@staticmethod
def has_tma_epilogue(op) -> bool:
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib
result = False
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
@@ -320,7 +320,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
See https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L283-L285 # noqa: B950
"""
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib
if op.gemm_kind != cutlass_lib.GemmKind.Universal3x:
return False
@@ -350,8 +350,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
epilogue_nodes: Optional[List[IRNode]] = None,
) -> Tuple[str, str]:
assert cutlass_utils.try_import_cutlass()
import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
from torch._inductor.codegen.cuda.cutlass_lib_extensions.gemm_operation_extensions import (
EmitGemmUniversal3xInstanceWithEVT,
@@ -424,7 +424,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined]
) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
import cutlass_library.library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib
# Skip simt kernels
if (
@@ -510,8 +510,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
@@ -645,8 +645,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
)
assert cutlass_utils.try_import_cutlass()
import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
assert isinstance(
op, cutlass_gemm_op.GemmOperation

View File

@@ -2530,7 +2530,7 @@ class TritonScheduling(BaseScheduling):
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
with V.set_kernel_handler(kernel): # type: ignore[call-arg]
with V.set_kernel_handler(kernel):
src_code = kernel.codegen_kernel()
for node in node_schedule:
@@ -2651,7 +2651,7 @@ class TritonScheduling(BaseScheduling):
node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
# finalize must be called after adding epilogue above
with V.set_kernel_handler(kernel): # type: ignore[call-arg]
with V.set_kernel_handler(kernel):
# TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
src_code = (
partial_code
@@ -2696,7 +2696,7 @@ class TritonScheduling(BaseScheduling):
subkernel,
)
with V.set_kernel_handler(subkernel): # type: ignore[call-arg]
with V.set_kernel_handler(subkernel):
for node in node_schedule:
if node not in (EnableReduction, DisableReduction):
node.mark_run()

View File

@@ -1,12 +1,17 @@
import math
from enum import IntEnum
from typing import TYPE_CHECKING
import torch
from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V
if TYPE_CHECKING:
from torch._inductor.scheduler import BaseSchedulerNode
class NCCL_COLL(IntEnum):
ALL_REDUCE = 0
@@ -33,7 +38,7 @@ def get_gpu_type() -> NVIDIA_GPU_TYPE:
return NVIDIA_GPU_TYPE.AMPERE
def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL: # type: ignore[name-defined]
def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:
if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)):
return NCCL_COLL.ALL_REDUCE
elif isinstance(
@@ -136,7 +141,7 @@ llMaxBws = torch.tensor(
)
def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float: # type: ignore[name-defined]
def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).
@@ -158,7 +163,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float: # ty
# Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
# TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
num_gpus_per_node = 8
_, _, group_size = snode.node.constant_args
_, _, group_size = snode.node.constant_args # type: ignore[attr-defined]
nNodes = math.ceil(group_size / num_gpus_per_node)
nRanks = group_size # this is total # of gpus globally that participate in this collective op
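
The `TYPE_CHECKING` import added at the top of this file is what lets the
quoted "BaseSchedulerNode" annotations drop their `name-defined` ignores. A
minimal sketch of the pattern, assuming placeholder names not taken from
this PR:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by the type checker; avoids the cost (or cycle) of a
        # runtime import while still giving mypy a real definition.
        from mypackage.scheduler import SchedulerNode

    def describe(node: "SchedulerNode") -> str:
        # The string annotation resolves because mypy sees the guarded import.
        return type(node).__name__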

View File

@@ -56,7 +56,7 @@ from .utils import get_dtype_size, has_incompatible_cudagraph_ops
from .virtualized import V
if config.is_fbcode():
from torch._inductor.fb.utils import time_and_log # type: ignore[import]
from torch._inductor.fb.utils import time_and_log
else:
# no-op decorator
def time_and_log(attr: str):
@@ -208,7 +208,7 @@ def count_bytes_inner(
post_grad_passes(gm, False)
graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
with V.set_graph_handler(graph), V.set_real_inputs(example_inputs): # type: ignore[call-arg]
with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):
graph.run(*example_inputs)
num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
metrics.num_bytes_accessed += num_bytes
@@ -610,7 +610,7 @@ def fx_codegen_and_compile(
for out in graph.graph_outputs:
if hasattr(out, "layout"):
output_strides.append(
tuple( # type: ignore[arg-type]
tuple(
V.graph.sizevars.size_hint(s) for s in out.layout.stride
)
)
@@ -1034,9 +1034,7 @@ def compile_fx(
"triton.autotune_cublasLt": False,
"triton.cudagraphs": False,
}
), V.set_real_inputs(
example_inputs_
): # type: ignore[call-arg]
), V.set_real_inputs(example_inputs_):
inputs_ = example_inputs_
if isinstance(model_, torch.fx.GraphModule):
fake_inputs = [
@@ -1233,7 +1231,7 @@ def compile_fx(
with V.set_fake_mode(fake_mode), compiled_autograd.disable():
return inference_compiler(unlifted_gm, example_inputs_)
with V.set_fake_mode(fake_mode), torch._guards.tracing( # type: ignore[call-arg]
with V.set_fake_mode(fake_mode), torch._guards.tracing(
tracing_context
), compiled_autograd.disable():
return aot_autograd(

View File

@@ -291,7 +291,7 @@ compile_threads = decide_compile_threads()
# gemm autotuning global cache dir
if is_fbcode():
from libfb.py import parutil # type: ignore[import]
from libfb.py import parutil
try:
if __package__:

View File

@@ -22,7 +22,7 @@ Dep = Union["MemoryDep", "StarDep", "WeakDep"]
class MemoryDep(typing.NamedTuple):
name: str
index: sympy.Expr # type: ignore[assignment]
index: sympy.Expr
var_names: Tuple[sympy.Symbol, ...]
size: Tuple[sympy.Expr, ...]
@@ -138,7 +138,7 @@ class WeakDep(typing.NamedTuple):
class IndexExprDep(typing.NamedTuple):
index: sympy.Expr # type: ignore[assignment]
index: sympy.Expr
var_names: Tuple[sympy.Symbol, ...]
size: Tuple[sympy.Expr, ...]
@@ -357,7 +357,7 @@ def extract_read_writes(
):
args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
rw = RecordLoadStore(var_ranges, normalize=normalize)
with V.set_ops_handler(rw): # type: ignore[call-arg]
with V.set_ops_handler(rw):
fn(*args)
if normalize:
@@ -377,7 +377,7 @@ def extract_read_writes(
def extract_input_node_reduction_ranges( # noqa: F722
input_node: ".ir.TensorBox", # type: ignore[valid-type] # noqa: F722
input_node: "torch._inductor.ir.TensorBox",
) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
"""
Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
@@ -397,7 +397,7 @@ def extract_input_node_reduction_ranges( # noqa: F722
else:
return (None, None)
if not isinstance(input_node.data.data, Loops):
if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined]
# Other IRNodes do not have reduction_ranges.
return (None, None)

View File

@@ -14,9 +14,7 @@ from ..pattern_matcher import (
)
if config.is_fbcode():
from torch._inductor.fb.utils import ( # type: ignore[import] # noqa: F401
get_everpaste_url,
)
from torch._inductor.fb.utils import get_everpaste_url
try:
# importing this will register fbgemm lowerings for inductor

View File

@@ -101,7 +101,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
gm.graph.lint()
if config.is_fbcode():
from torch._inductor.fb.utils import get_everpaste_url # type: ignore[import]
from torch._inductor.fb.utils import get_everpaste_url
log.info(
"Print graph after recompile in post grad passes: %s",

View File

@@ -78,7 +78,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
gm.recompile()
if config.is_fbcode():
from torch._inductor.fb.utils import get_everpaste_url # type: ignore[import]
from torch._inductor.fb.utils import get_everpaste_url
log.info(
"Print graph after recompile in pre grad passes: %s",

View File

@@ -984,7 +984,7 @@ class Reduction(Loops):
if split == -1:
assert input_node is not None
new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
input_node
input_node # type: ignore[arg-type]
)
assert new_ranges is not None
assert new_reduction_ranges is not None

View File

@@ -214,9 +214,9 @@ def mm_args(mat1, mat2, *others, layout=None, out_dtype=None, use_4x2_dim=False)
def addmm_epilogue(dtype, alpha, beta):
def epilogue(acc, bias):
if alpha != 1:
acc = V.ops.mul(acc, V.ops.constant(alpha, dtype)) # type: ignore[attr-defined]
acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
if beta != 1:
bias = V.ops.mul(bias, V.ops.constant(beta, dtype)) # type: ignore[attr-defined]
return V.ops.add(acc, bias) # type: ignore[attr-defined]
bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
return V.ops.add(acc, bias)
return epilogue
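
For orientation only (this reference is not part of the diff): the epilogue
closure above applies the standard addmm scaling, which in eager terms
amounts to

    import torch

    def addmm_reference(bias, mat1, mat2, *, alpha=1, beta=1):
        # out = beta * bias + alpha * (mat1 @ mat2), matching torch.addmm
        return beta * bias + alpha * (mat1 @ mat2)

with the `alpha != 1` / `beta != 1` checks simply skipping multiplications
that would be no-ops.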

View File

@@ -265,7 +265,7 @@ class TritonTemplateKernel(TritonKernel):
input_node.freeze_layout()
epilogue_args.append(input_node.make_loader()(index_symbols))
V.ops.store( # type: ignore[attr-defined]
V.ops.store(
self.output_node.get_name(),
output_index,
self.epilogue_fn(*epilogue_args),
@@ -649,7 +649,7 @@ class ExternKernelCaller(ChoiceCaller):
else:
algo = self.to_callable()
out_new = algo(*args)
torch._C._dynamo.guards.assert_size_stride( # type: ignore[attr-defined]
torch._C._dynamo.guards.assert_size_stride(
out_new, tuple(out.size()), tuple(out.stride())
)
out.copy_(out_new) # for correctness checking

View File

@@ -989,7 +989,7 @@ def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix=""):
info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
try:
import colorama # type: ignore[import]
import colorama
if ms > 0.012 and gb_per_s < 650:
info_str = colorama.Fore.RED + info_str + colorama.Fore.RESET

View File

@@ -138,7 +138,7 @@ class KernelFormatterHandler:
)
formatter.output.writeline(f"{lhs} = {name}")
with V.set_ops_handler(formatter), patch.object( # type: ignore[call-arg]
with V.set_ops_handler(formatter), patch.object(
FlexibleLayout, "allow_indexing", True
):
result = ir_fn(*args)