Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Delete a bunch of type-ignores (#113990)
* Replaced `ignore[import]` with mypy config file entries
* Removed a bunch of ignores around previously-fixed attr-defined / call-arg issues
* Fixed some invalid / undefined types; added a few more type-ignores to squelch the downstream errors this exposed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113990
Approved by: https://github.com/eellison, https://github.com/Skylion007
ghstack dependencies: #113979
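For context, a minimal sketch of the substitution described in the first bullet (the module name `somepackage` is a placeholder, not one of the modules touched here): instead of suppressing a missing-import error at every import site with

    import somepackage  # type: ignore[import]

the import is left unannotated and the error is silenced once, for the whole module family, in the mypy config:

    [mypy-somepackage.*]
    ignore_missing_imports = True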
@@ -24,6 +24,12 @@ files =
 # understood by mypy.
 python_version = 3.11

+[mypy-colorama.*]
+ignore_missing_imports = True
+
+[mypy-cutlass_library.*]
+ignore_missing_imports = True
+
 [mypy-deeplearning.*]
 ignore_missing_imports = True

@@ -33,11 +39,17 @@ ignore_missing_imports = True
 [mypy-einops.*]
 ignore_missing_imports = True

+[mypy-libfb.*]
+ignore_missing_imports = True
+
 # sympy is too dynamic, hard to type properly
 [mypy-sympy.*]
 ignore_missing_imports = True
 follow_imports = skip

+[mypy-torch.*.fb.*]
+ignore_missing_imports = True
+
 # FIXME: importing this creates lots of type errors
 [mypy-torch._dynamo.variables.*]
 follow_imports = skip
@@ -34,7 +34,7 @@ def openxla_eval_boxed(model, fake_tensor_inputs):

 def xla_backend_helper(model, fake_tensor_inputs, boxed=False):
     try:
-        import torch_xla.core.dynamo_bridge as bridge  # type: ignore[import]
+        import torch_xla.core.dynamo_bridge as bridge
     except ImportError as e:
         raise ImportError(
             "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
@@ -34,7 +34,7 @@ inductor_config = import_module("torch._inductor.config")
 use_buck = inductor_config.is_fbcode()

 if use_buck:
-    import libfb.py.build_info  # type: ignore[import]
+    import libfb.py.build_info


 extra_deps = []
@@ -157,7 +157,7 @@ DONT_WRAP_FILES = {

 def _debug_get_cache_entry_list(
     code: Union[types.CodeType, Callable[..., Any]]
-) -> List[CacheEntry]:  # type: ignore[valid-type]
+) -> List[CacheEntry]:
     """
     Given a code object or a callable object, retrieve the cache entries
     stored in this code.
@@ -53,7 +53,7 @@ class DynamoCallbackFn(Protocol):
     def __call__(
         self,
         frame: DynamoFrameType,
-        cache_entry: Optional[CacheEntry],  # type: ignore[valid-type]
+        cache_entry: Optional[CacheEntry],
         frame_state: FrameState,
     ) -> Optional[GuardedCode]:
         ...
@@ -65,7 +65,7 @@ if config.is_fbcode():
     from triton.fb import build_paths
     from triton.fb.build import _run_build_command

-    from torch._inductor.fb.utils import (  # type: ignore[import]
+    from torch._inductor.fb.utils import (
         log_global_cache_errors,
         log_global_cache_stats,
         log_global_cache_vals,
@@ -78,7 +78,7 @@ class CUDACPPScheduling(BaseScheduling):
         # We can fuse a Pointwise op that depends on the last fused epilogue node
         # if any. If there is no epilogue node yet, it needs to depend on the template
         # node
-        node_name = additional_node.get_computed_buffer_name()  # type: ignore[attr-defined]
+        node_name = additional_node.get_computed_buffer_name()
         if node_name is None:
             return False

@@ -84,7 +84,7 @@ class CutlassEVTEpilogueTypeFormatter:
             template_output_node_name, evt_type_name
         )

-        with virtualized.V.set_ops_handler(formatter), patch.object(  # type: ignore[call-arg]
+        with virtualized.V.set_ops_handler(formatter), patch.object(
             FlexibleLayout, "allow_indexing", True
         ):
             for node in epilogue_nodes:
@@ -253,7 +253,7 @@ class CutlassEVTEpilogueArgumentFormatter:
             template_output_node_name,
         )

-        with virtualized.V.set_ops_handler(formatter), patch.object(  # type: ignore[call-arg]
+        with virtualized.V.set_ops_handler(formatter), patch.object(
             FlexibleLayout, "allow_indexing", True
         ):
             for node in epilogue_nodes:
@@ -3,8 +3,8 @@ from ..cutlass_utils import try_import_cutlass
 if try_import_cutlass():
     import enum

-    from cutlass_library.library import *  # type: ignore[import]  # noqa: F401, F403
-    from cutlass_library.gemm_operation import *  # type: ignore[import]  # noqa: F401, F403
+    from cutlass_library.library import *  # noqa: F401, F403
+    from cutlass_library.gemm_operation import *  # noqa: F401, F403

     # copied / modified from original at
     # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658
@@ -72,9 +72,9 @@ def try_import_cutlass() -> bool:
             os.symlink(cutlass_py_full_path, dst_link)
         sys.path.append(tmp_cutlass_py_full_path)
         try:
-            import cutlass_library.generator  # type: ignore[import]  # noqa: F401
-            import cutlass_library.library  # type: ignore[import]  # noqa: F401
-            import cutlass_library.manifest  # type: ignore[import]  # noqa: F401
+            import cutlass_library.generator  # noqa: F401
+            import cutlass_library.library  # noqa: F401
+            import cutlass_library.manifest  # noqa: F401

             return True

@@ -139,8 +139,8 @@ def _gen_ops_cached(arch, version) -> List[Any]:

     # Import cutlass python scripts.
     assert try_import_cutlass()
-    import cutlass_library.generator as cutlass_generator  # type: ignore[import]
-    import cutlass_library.manifest as cutlass_manifest  # type: ignore[import]
+    import cutlass_library.generator as cutlass_generator
+    import cutlass_library.manifest as cutlass_manifest

     if arch is None or version is None:
         log.error(
@@ -184,7 +184,7 @@ def dtype_match(
 ) -> bool:
     # Import cutlass python scripts.
     assert try_import_cutlass()
-    import cutlass_library  # type: ignore[import]
+    import cutlass_library

     if torch_dtype == torch.float:
         return (
@@ -266,7 +266,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
     @staticmethod
     def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]":  # type: ignore[name-defined]
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib

         if torch_layout.stride[-1] == 1:
             return cutlass_lib.LayoutType.RowMajor
@@ -280,7 +280,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
         cutlass_layout: "cutlass_lib.LayoutType",  # type: ignore[name-defined]
     ) -> "cutlass_lib.LayoutType":  # type: ignore[name-defined]
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib

         if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
             return cutlass_lib.LayoutType.ColumnMajor
@@ -303,7 +303,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
     @staticmethod
     def has_tma_epilogue(op) -> bool:
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib

         result = False
         if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
@@ -320,7 +320,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
         See https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L283-L285  # noqa: B950
         """
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib

         if op.gemm_kind != cutlass_lib.GemmKind.Universal3x:
             return False
@@ -350,8 +350,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
         epilogue_nodes: Optional[List[IRNode]] = None,
     ) -> Tuple[str, str]:
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.gemm_operation as cutlass_gemm_op  # type: ignore[import]
-        import cutlass_library.library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.gemm_operation as cutlass_gemm_op
+        import cutlass_library.library as cutlass_lib

         from torch._inductor.codegen.cuda.cutlass_lib_extensions.gemm_operation_extensions import (
             EmitGemmUniversal3xInstanceWithEVT,
@@ -424,7 +424,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
         op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]
     ) -> "cutlass_library.gemm_op.GemmOperation":  # type: ignore[name-defined]
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.library as cutlass_lib

         # Skip simt kernels
         if (
@@ -510,8 +510,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):

     def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]":  # type: ignore[name-defined]
         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.gemm_operation as cutlass_gemm_op  # type: ignore[import]
-        import cutlass_library.library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.gemm_operation as cutlass_gemm_op
+        import cutlass_library.library as cutlass_lib

         ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
         res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
@@ -645,8 +645,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
         )

         assert cutlass_utils.try_import_cutlass()
-        import cutlass_library.gemm_operation as cutlass_gemm_op  # type: ignore[import]
-        import cutlass_library.library as cutlass_lib  # type: ignore[import]
+        import cutlass_library.gemm_operation as cutlass_gemm_op
+        import cutlass_library.library as cutlass_lib

         assert isinstance(
             op, cutlass_gemm_op.GemmOperation
@@ -2530,7 +2530,7 @@ class TritonScheduling(BaseScheduling):

         self.codegen_node_schedule_with_kernel(node_schedule, kernel)

-        with V.set_kernel_handler(kernel):  # type: ignore[call-arg]
+        with V.set_kernel_handler(kernel):
             src_code = kernel.codegen_kernel()

         for node in node_schedule:
@@ -2651,7 +2651,7 @@ class TritonScheduling(BaseScheduling):
             node.codegen(kernel.split_and_set_ranges(node.get_ranges()))

         # finalize must be called after adding epilogue above
-        with V.set_kernel_handler(kernel):  # type: ignore[call-arg]
+        with V.set_kernel_handler(kernel):
             # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
             src_code = (
                 partial_code
@@ -2696,7 +2696,7 @@ class TritonScheduling(BaseScheduling):
                 subkernel,
             )

-            with V.set_kernel_handler(subkernel):  # type: ignore[call-arg]
+            with V.set_kernel_handler(subkernel):
                 for node in node_schedule:
                     if node not in (EnableReduction, DisableReduction):
                         node.mark_run()
@@ -1,12 +1,17 @@
 import math
 from enum import IntEnum

+from typing import TYPE_CHECKING
+
 import torch
 from . import ir

 from .utils import get_dtype_size, sympy_product
 from .virtualized import V

+if TYPE_CHECKING:
+    from torch._inductor.scheduler import BaseSchedulerNode
+

 class NCCL_COLL(IntEnum):
     ALL_REDUCE = 0
@@ -33,7 +38,7 @@ def get_gpu_type() -> NVIDIA_GPU_TYPE:
         return NVIDIA_GPU_TYPE.AMPERE


-def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:  # type: ignore[name-defined]
+def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:
     if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)):
         return NCCL_COLL.ALL_REDUCE
     elif isinstance(
@@ -136,7 +141,7 @@ llMaxBws = torch.tensor(
 )


-def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:  # type: ignore[name-defined]
+def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
     """
     Returns estimated NCCL collective runtime in nanoseconds (ns).

@@ -158,7 +163,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float: # ty
     # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
     # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
     num_gpus_per_node = 8
-    _, _, group_size = snode.node.constant_args
+    _, _, group_size = snode.node.constant_args  # type: ignore[attr-defined]
     nNodes = math.ceil(group_size / num_gpus_per_node)
     nRanks = group_size  # this is total # of gpus globally that participate in this collective op

@@ -56,7 +56,7 @@ from .utils import get_dtype_size, has_incompatible_cudagraph_ops
 from .virtualized import V

 if config.is_fbcode():
-    from torch._inductor.fb.utils import time_and_log  # type: ignore[import]
+    from torch._inductor.fb.utils import time_and_log
 else:
     # no-op decorator
     def time_and_log(attr: str):
@@ -208,7 +208,7 @@ def count_bytes_inner(
     post_grad_passes(gm, False)

     graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
-    with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):  # type: ignore[call-arg]
+    with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):
         graph.run(*example_inputs)
         num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
         metrics.num_bytes_accessed += num_bytes
@@ -610,7 +610,7 @@ def fx_codegen_and_compile(
         for out in graph.graph_outputs:
             if hasattr(out, "layout"):
                 output_strides.append(
-                    tuple(  # type: ignore[arg-type]
+                    tuple(
                         V.graph.sizevars.size_hint(s) for s in out.layout.stride
                     )
                 )
@@ -1034,9 +1034,7 @@ def compile_fx(
             "triton.autotune_cublasLt": False,
             "triton.cudagraphs": False,
         }
-    ), V.set_real_inputs(
-        example_inputs_
-    ):  # type: ignore[call-arg]
+    ), V.set_real_inputs(example_inputs_):
         inputs_ = example_inputs_
         if isinstance(model_, torch.fx.GraphModule):
             fake_inputs = [
@@ -1233,7 +1231,7 @@ def compile_fx(
             with V.set_fake_mode(fake_mode), compiled_autograd.disable():
                 return inference_compiler(unlifted_gm, example_inputs_)

-    with V.set_fake_mode(fake_mode), torch._guards.tracing(  # type: ignore[call-arg]
+    with V.set_fake_mode(fake_mode), torch._guards.tracing(
         tracing_context
     ), compiled_autograd.disable():
         return aot_autograd(
@@ -291,7 +291,7 @@ compile_threads = decide_compile_threads()

 # gemm autotuning global cache dir
 if is_fbcode():
-    from libfb.py import parutil  # type: ignore[import]
+    from libfb.py import parutil

     try:
         if __package__:
@@ -22,7 +22,7 @@ Dep = Union["MemoryDep", "StarDep", "WeakDep"]

 class MemoryDep(typing.NamedTuple):
     name: str
-    index: sympy.Expr  # type: ignore[assignment]
+    index: sympy.Expr
     var_names: Tuple[sympy.Symbol, ...]
     size: Tuple[sympy.Expr, ...]

@@ -138,7 +138,7 @@ class WeakDep(typing.NamedTuple):


 class IndexExprDep(typing.NamedTuple):
-    index: sympy.Expr  # type: ignore[assignment]
+    index: sympy.Expr
     var_names: Tuple[sympy.Symbol, ...]
     size: Tuple[sympy.Expr, ...]

@@ -357,7 +357,7 @@ def extract_read_writes(
 ):
     args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
     rw = RecordLoadStore(var_ranges, normalize=normalize)
-    with V.set_ops_handler(rw):  # type: ignore[call-arg]
+    with V.set_ops_handler(rw):
         fn(*args)

     if normalize:
@@ -377,7 +377,7 @@ def extract_read_writes(


 def extract_input_node_reduction_ranges(  # noqa: F722
-    input_node: ".ir.TensorBox",  # type: ignore[valid-type]  # noqa: F722
+    input_node: "torch._inductor.ir.TensorBox",
 ) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
     """
     Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
@@ -397,7 +397,7 @@ def extract_input_node_reduction_ranges(  # noqa: F722
         else:
             return (None, None)

-    if not isinstance(input_node.data.data, Loops):
+    if not isinstance(input_node.data.data, Loops):  # type: ignore[attr-defined]
         # Other IRNodes do not have reduction_ranges.
         return (None, None)

@@ -14,9 +14,7 @@ from ..pattern_matcher import (
 )

 if config.is_fbcode():
-    from torch._inductor.fb.utils import (  # type: ignore[import]  # noqa: F401
-        get_everpaste_url,
-    )
+    from torch._inductor.fb.utils import get_everpaste_url

 try:
     # importing this will register fbgemm lowerings for inductor
@@ -101,7 +101,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     gm.graph.lint()

     if config.is_fbcode():
-        from torch._inductor.fb.utils import get_everpaste_url  # type: ignore[import]
+        from torch._inductor.fb.utils import get_everpaste_url

         log.info(
             "Print graph after recompile in post grad passes: %s",
@@ -78,7 +78,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
     gm.recompile()

     if config.is_fbcode():
-        from torch._inductor.fb.utils import get_everpaste_url  # type: ignore[import]
+        from torch._inductor.fb.utils import get_everpaste_url

         log.info(
             "Print graph after recompile in pre grad passes: %s",
@@ -984,7 +984,7 @@ class Reduction(Loops):
         if split == -1:
             assert input_node is not None
             new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
-                input_node
+                input_node  # type: ignore[arg-type]
             )
             assert new_ranges is not None
             assert new_reduction_ranges is not None
@@ -214,9 +214,9 @@ def mm_args(mat1, mat2, *others, layout=None, out_dtype=None, use_4x2_dim=False)
 def addmm_epilogue(dtype, alpha, beta):
     def epilogue(acc, bias):
         if alpha != 1:
-            acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))  # type: ignore[attr-defined]
+            acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
         if beta != 1:
-            bias = V.ops.mul(bias, V.ops.constant(beta, dtype))  # type: ignore[attr-defined]
-        return V.ops.add(acc, bias)  # type: ignore[attr-defined]
+            bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
+        return V.ops.add(acc, bias)

     return epilogue
@@ -265,7 +265,7 @@ class TritonTemplateKernel(TritonKernel):
             input_node.freeze_layout()
             epilogue_args.append(input_node.make_loader()(index_symbols))

-        V.ops.store(  # type: ignore[attr-defined]
+        V.ops.store(
             self.output_node.get_name(),
             output_index,
             self.epilogue_fn(*epilogue_args),
@@ -649,7 +649,7 @@ class ExternKernelCaller(ChoiceCaller):
         else:
             algo = self.to_callable()
             out_new = algo(*args)
-            torch._C._dynamo.guards.assert_size_stride(  # type: ignore[attr-defined]
+            torch._C._dynamo.guards.assert_size_stride(
                 out_new, tuple(out.size()), tuple(out.stride())
             )
             out.copy_(out_new)  # for correctness checking
@@ -989,7 +989,7 @@ def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
 def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix=""):
     info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
     try:
-        import colorama  # type: ignore[import]
+        import colorama

         if ms > 0.012 and gb_per_s < 650:
             info_str = colorama.Fore.RED + info_str + colorama.Fore.RESET
@@ -138,7 +138,7 @@ class KernelFormatterHandler:
             )
             formatter.output.writeline(f"{lhs} = {name}")

-        with V.set_ops_handler(formatter), patch.object(  # type: ignore[call-arg]
+        with V.set_ops_handler(formatter), patch.object(
             FlexibleLayout, "allow_indexing", True
         ):
             result = ir_fn(*args)