diff --git a/pyrefly.toml b/pyrefly.toml index c1a9e3b003fa..3f048d2249c9 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -20,8 +20,10 @@ project-includes = [ project-excludes = [ # ==== below will be enabled directory by directory ==== # ==== to test Pyrefly on a specific directory, simply comment it out ==== - "torch/_inductor/**", - # formatting issues + "torch/_inductor/runtime", + "torch/_inductor/codegen", + # formatting issues, will turn on after adjusting where suppressions can be + # in import statements "torch/linalg/__init__.py", "torch/package/importer.py", "torch/package/_package_pickler.py", @@ -31,6 +33,8 @@ project-excludes = [ "torch/_export/utils.py", "torch/fx/experimental/unification/multipledispatch/__init__.py", "torch/nn/modules/__init__.py", + "torch/_inductor/codecache.py", + "torch/distributed/elastic/metrics/__init__.py", # ==== "benchmarks/instruction_counts/main.py", "benchmarks/instruction_counts/definitions/setup.py", diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index d287337afaa6..9c1090684016 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -132,6 +132,7 @@ def aoti_compile_and_package( ) or ( isinstance(package_path, (str, os.PathLike)) + # pyrefly: ignore # no-matching-overload and os.fspath(package_path).endswith(".pt2") ) ), ( @@ -151,6 +152,7 @@ def aoti_compile_and_package( return aot_inductor_minifier_wrapper( _aoti_compile_and_package_inner, exported_program, + # pyrefly: ignore # bad-argument-type package_path=package_path, inductor_configs=inductor_configs, ) diff --git a/torch/_inductor/analysis/profile_analysis.py b/torch/_inductor/analysis/profile_analysis.py index 134d06528c0d..a9f89009c210 100644 --- a/torch/_inductor/analysis/profile_analysis.py +++ b/torch/_inductor/analysis/profile_analysis.py @@ -49,6 +49,7 @@ def register_adapter( AdapterType, ]: def decorator(func: AdapterType) -> AdapterType: + # pyrefly: ignore # unknown-name global _adapters_map if isinstance(aten, str): @@ -412,9 +413,11 @@ class JsonProfile: if dtype is None: self.dtype = None elif isinstance(dtype, torch.dtype): + # pyrefly: ignore # bad-assignment self.dtype = dtype else: if dtype in _dtype_map: + # pyrefly: ignore # bad-assignment self.dtype = _dtype_map[dtype] else: self.dtype = None @@ -653,6 +656,7 @@ class JsonProfile: t1, self_name, t2, other_name ) tab_string = create_ret(table_headers, table_rows) + # pyrefly: ignore # bad-argument-type ret.append(f"{self._devices[device_idx]}:\n{tab_string}") return "\n".join(ret) self._compute_stats() @@ -663,6 +667,7 @@ class JsonProfile: for idx, table in self_tables.items(): table_headers, table_rows = table tab_string = create_ret(table_headers, table_rows) + # pyrefly: ignore # bad-argument-type ret.append(f"{self._devices[idx]}:\n{tab_string}") return "\n".join(ret) diff --git a/torch/_inductor/analyze_preserves_zero_mask.py b/torch/_inductor/analyze_preserves_zero_mask.py index 90d0ff80c5f0..0096103670a3 100644 --- a/torch/_inductor/analyze_preserves_zero_mask.py +++ b/torch/_inductor/analyze_preserves_zero_mask.py @@ -106,6 +106,7 @@ class RecordLowPrecisionOps(DefaultHandler): pass @staticmethod + # pyrefly: ignore # bad-override def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr: return sympy.S.Zero diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index a504b54f132b..85ea0a79d5f4 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -557,6 +557,7 @@ class 
GPUDeviceBenchmarkMixin: res = benchmarker.benchmark_gpu(fn) device_interface.synchronize() # shake out any CUDA errors + # pyrefly: ignore # bad-return return res diff --git a/torch/_inductor/await_utils.py b/torch/_inductor/await_utils.py index a549674d5cd7..036c7e3457d7 100644 --- a/torch/_inductor/await_utils.py +++ b/torch/_inductor/await_utils.py @@ -149,6 +149,7 @@ def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[ task_factory = task_factories[0] if task_factory is None: if sys.version_info >= (3, 11): + # pyrefly: ignore # bad-argument-type task = asyncio.Task(coro, loop=loop, context=context) else: task = asyncio.Task(coro, loop=loop) diff --git a/torch/_inductor/cache.py b/torch/_inductor/cache.py index aff54a126b31..07ff7912e3cd 100644 --- a/torch/_inductor/cache.py +++ b/torch/_inductor/cache.py @@ -292,6 +292,7 @@ class OnDiskCache(AsyncCache[Key, Value]): raise CacheError( f"Failed to get fpath for key {key!r}, key is not pickle-able." ) from err + # pyrefly: ignore # bad-argument-type assert_never(key) def _flock_from_fpath(self: Self, fpath: Path) -> FileLock: @@ -306,6 +307,7 @@ class OnDiskCache(AsyncCache[Key, Value]): # for fpath.name[:4]; this is more than enough unique locks to not # cause additional overhead from shared locks and it also saves our # cache dir from becoming 50 percent locks + # pyrefly: ignore # bad-return return FileLock(str(fpath.parent / "locks" / fpath.name[:4]) + ".lock") @property diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index a2f1f05183b1..84c6a3089d80 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -588,6 +588,7 @@ class InductorChoices: and memory_score > 0 ) + # pyrefly: ignore # bad-return return ( template_score, node1.is_reduction() == node2.is_reduction() and memory_score > 0, diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 70351653a0e0..9025ecdc56fa 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2171,6 +2171,7 @@ end data_ptr, ctypes.POINTER(ctypes.c_ubyte * nbytes), ) + # pyrefly: ignore # missing-attribute raw_bytes = bytes(raw_array.contents) return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) @@ -2362,6 +2363,7 @@ end ): current_arch = _nvcc_arch_as_compile_option() cmd = ( + # pyrefly: ignore # unbound-name f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " # Triton only allows generating PTX version as same as the current arch f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " @@ -2564,6 +2566,7 @@ def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, No # convert any kwarg-only arguments to kwargs kwargs = dict() + # pyrefly: ignore # missing-attribute for func_arg, conv_arg in zip(func._schema.arguments, converted_args): if func_arg.kwarg_only: kwargs[func_arg.name] = conv_arg @@ -2745,10 +2748,14 @@ class CppCodeCache: main_build_option = CppTorchDeviceOptions( compile_only=bool(optimized_code), min_optimize=optimized_code is not None, + # pyrefly: ignore # bad-argument-type **compile_command, ) optimized_build_option = CppTorchDeviceOptions( - compile_only=True, **compile_command + # pyrefly: ignore # bad-argument-type + compile_only=True, + # pyrefly: ignore # bad-argument-type + **compile_command, ) def get_hashable_command_line(build_option: BuildOptionsBase) -> str: @@ -2797,6 +2804,7 @@ class CppCodeCache: # decision if that ever changes. 
if optimized_code and (header := _get_cpp_prefix_header(device_type)): optimized_build_option.precompiled_header = _precompile_header( + # pyrefly: ignore # unbound-name header, optimized_cmd_line, **compile_command, @@ -2827,6 +2835,7 @@ class CppCodeCache: main_builder.get_target_file_path(), optimized_builder.get_target_file_path(), ], + # pyrefly: ignore # bad-argument-type BuildOption=CppTorchDeviceOptions(**compile_command), output_dir=output_dir, ) @@ -2985,6 +2994,7 @@ class CppPythonBindingsCodeCache(CppCodeCache): ) @classmethod + # pyrefly: ignore # bad-override def _load_library_inner(cls, path: str, key: str) -> ModuleType: os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str( torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined] @@ -3256,10 +3266,12 @@ class HalideCodeCache(CppPythonBindingsCodeCache): buffer_names = [] for i, arg in enumerate(meta.argtypes): if arg.is_buffer(): + # pyrefly: ignore # bad-argument-type buffer_names.append(f"&hl_buf_{i}") buffers.extend(cls._codegen_buffer(f"hl_buf_{i}", arg, is_cuda)) else: assert "*" not in arg.ctype + # pyrefly: ignore # bad-argument-type buffer_names.append(arg.name) buffers = "\n".join([f" {line}" for line in buffers]).lstrip() @@ -3514,6 +3526,7 @@ def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None: ci = cmd.index("-o") assert isinstance(ci, int) + # pyrefly: ignore # unsupported-operation cmd[ci + 1] = Out() repl = textwrap.indent( textwrap.dedent( diff --git a/torch/_inductor/comm_lowering.py b/torch/_inductor/comm_lowering.py index 6d1d4b6cf293..5d2e39d79307 100644 --- a/torch/_inductor/comm_lowering.py +++ b/torch/_inductor/comm_lowering.py @@ -208,6 +208,7 @@ def register_comm_lowerings(): # in-place reuse. Therefore, we tell the scheduler to not fuse it. inp.realize() V.graph.no_fuse_buffer_names.add(inp.get_name()) + # pyrefly: ignore # bad-assignment inp = ir.ExternKernel.require_contiguous(inp) # Because we are lowering as inplace c10d.all_reduce_, we should generate # _AllReduce_Kernel instead of _AllReduceKernel. 
@@ -232,6 +233,7 @@ def register_comm_lowerings(): return inp # Lower as c10d.all_reduce_ + # pyrefly: ignore # bad-assignment inp = ir.ExternKernel.require_contiguous(inp) ir._AllReduce_Kernel.create_inplace( c10d.all_reduce_.default, diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index b57e64e296f3..86f272c8b24e 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -465,6 +465,7 @@ def _reorder_communication_preserving_peak_memory_internal( while _next[curr] is not None: if iterative_recompute_error: break + # pyrefly: ignore # bad-argument-type if contains_collective(curr): if debug_num_collectives_to_reorder is not None and ( num_processed_collectives >= debug_num_collectives_to_reorder @@ -825,8 +826,11 @@ def _schedule_for_comm( collective_cost > 0 and (candidate := get_overlapping_candidate()) is not None ): + # pyrefly: ignore # unbound-name ready.remove(candidate) + # pyrefly: ignore # unbound-name schedule(candidate.snode) + # pyrefly: ignore # unbound-name collective_cost -= snode_to_cost[candidate.snode] heapq.heapify(ready) @@ -1028,6 +1032,7 @@ def _sink_waits_iterative_internal( ): break + # pyrefly: ignore # bad-argument-type if contains_wait(curr) and curr not in processed_waits: processed_waits.add(curr) info = stats[curr] = SinkWaitInfo() @@ -1093,6 +1098,7 @@ def _sink_waits_iterative_internal( info.grouped_info = _group_names(gns) candidate = _next[candidate] continue + # pyrefly: ignore # unbound-name elif (data_dep is None) and both_contain_comms: info.limiting_factor = ( f"collective ordering {_group_names(gns)}" @@ -1365,6 +1371,7 @@ def reorder_compute_and_comm_for_overlap( snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs ) print(f"final {peak_memory=}") + # pyrefly: ignore # bad-return return order @@ -1632,6 +1639,7 @@ def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None: KeywordArg("group_size"), KeywordArg("group_name"), ), + # pyrefly: ignore # bad-argument-type pass_dict=graph_pass, extra_check=lambda match: match.kwargs["item_idx"] == 0, ) @@ -1655,6 +1663,7 @@ def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None: return all_gather_into_tensor match.replace_by_example( + # pyrefly: ignore # bad-argument-type repl, [ kwargs["all_gather_inputs"], diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index bea42883ce6c..679bfbaac46c 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -271,8 +271,10 @@ def record_original_output_strides(gm: GraphModule) -> None: and (val := output.meta.get("val")) is not None and isinstance(val, torch.Tensor) ): + # pyrefly: ignore # unbound-name output_strides.append(val.stride()) else: + # pyrefly: ignore # bad-argument-type output_strides.append(None) output_node.meta["original_output_strides"] = output_strides @@ -1546,6 +1548,7 @@ class _InProcessFxCompile(FxCompile): node_runtimes = None if inductor_metrics_log.isEnabledFor(logging.INFO): num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() + # pyrefly: ignore # bad-assignment metrics.num_bytes_accessed += num_bytes metrics.node_runtimes += node_runtimes metrics.nodes_num_elem += nodes_num_elem @@ -1589,8 +1592,10 @@ class _InProcessFxCompile(FxCompile): disable = f"{disable} Found from {stack_trace}\n" else: disable = f"{disable}\n" + # pyrefly: ignore # unbound-name V.graph.disable_cudagraphs_reason = disable + # pyrefly: ignore # unbound-name if cudagraphs and not V.graph.disable_cudagraphs_reason: maybe_incompat_node = 
get_first_incompatible_cudagraph_node(gm) if maybe_incompat_node: @@ -1599,22 +1604,29 @@ class _InProcessFxCompile(FxCompile): "stack_trace", None ): disable = f"{disable} Found from {stack_trace}\n" + # pyrefly: ignore # unbound-name V.graph.disable_cudagraphs_reason = disable + # pyrefly: ignore # unbound-name if V.aot_compilation: assert isinstance( - compiled_fn, (str, list, torch.fx.GraphModule) + compiled_fn, + # pyrefly: ignore # unbound-name + (str, list, torch.fx.GraphModule), ), type(compiled_fn) return CompiledAOTI(compiled_fn) # TODO: Hoist this above V.aot_compilation + # pyrefly: ignore # unbound-name if cudagraphs and not V.graph.disable_cudagraphs_reason: from torch._inductor.cudagraph_utils import ( check_lowering_disable_cudagraph, ) + # pyrefly: ignore # unbound-name V.graph.disable_cudagraphs_reason = ( check_lowering_disable_cudagraph( + # pyrefly: ignore # unbound-name V.graph.device_node_mapping ) ) @@ -1622,23 +1634,29 @@ class _InProcessFxCompile(FxCompile): self._compile_stats[type(self)].codegen_and_compile += 1 if ( + # pyrefly: ignore # unbound-name torch._inductor.debug.RECORD_GRAPH_EXECUTION + # pyrefly: ignore # unbound-name and torch._inductor.debug.GRAPH_COMPILE_IDS is not None ): compile_id = str( + # pyrefly: ignore # unbound-name torch._guards.CompileContext.current_compile_id() ) graph_id = graph_kwargs.get("graph_id") if graph_id is not None: + # pyrefly: ignore # unbound-name torch._inductor.debug.GRAPH_COMPILE_IDS[graph_id] = ( compile_id ) return CompiledFxGraph( + # pyrefly: ignore # bad-argument-type compiled_fn, graph, gm, output_strides, + # pyrefly: ignore # unbound-name V.graph.disable_cudagraphs_reason, metrics_helper.get_deltas(), counters["inductor"] - inductor_counters, @@ -1680,15 +1698,18 @@ def fx_codegen_and_compile( from .compile_fx_async import _AsyncFxCompile from .compile_fx_ext import _OutOfProcessFxCompile + # pyrefly: ignore # unbound-name assert isinstance(scheme, _OutOfProcessFxCompile), ( "async is only valid with an out-of-process compile mode" ) + # pyrefly: ignore # unbound-name scheme = _AsyncFxCompile(scheme) if fx_compile_progressive: from .compile_fx_async import _ProgressiveFxCompile from .compile_fx_ext import _OutOfProcessFxCompile + # pyrefly: ignore # unbound-name assert isinstance(scheme, _OutOfProcessFxCompile), ( "progressive is only valid with an out-of-process compile mode" ) @@ -1698,8 +1719,10 @@ def fx_codegen_and_compile( # Use in-process compile for the fast version fast_scheme = _InProcessFxCompile() + # pyrefly: ignore # unbound-name scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs) + # pyrefly: ignore # unbound-name return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) @@ -1809,6 +1832,7 @@ def cudagraphify_impl( Assumes inputs[static_input_idxs[i]] are always the same memory address """ check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] + # pyrefly: ignore # annotation-mismatch static_input_idxs: OrderedSet[int] = OrderedSet( remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] ) @@ -1875,6 +1899,7 @@ def cudagraphify_impl( index_expanded_dims_and_copy_(dst, src, expanded_dims) new_inputs.clear() graph.replay() + # pyrefly: ignore # bad-return return static_outputs else: @@ -1890,6 +1915,7 @@ def cudagraphify_impl( index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims) new_inputs.clear() graph.replay() + # pyrefly: ignore # bad-return return static_outputs 
return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet()) @@ -1906,6 +1932,7 @@ def compile_fx_aot( # [See NOTE] Unwrapping subclasses AOT unwrap_tensor_subclass_parameters(model_) + # pyrefly: ignore # annotation-mismatch config_patches: dict[str, Any] = copy.deepcopy(config_patches or {}) if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper): @@ -2848,6 +2875,7 @@ def _aoti_flatten_inputs( Flatten the inputs to the graph module and return the flat inputs and options. Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options. """ + # pyrefly: ignore # missing-module-attribute from .compile_fx import graph_returns_tuple assert graph_returns_tuple(gm), ( diff --git a/torch/_inductor/compile_fx_ext.py b/torch/_inductor/compile_fx_ext.py index 7fd976a05ed9..743819af7e67 100644 --- a/torch/_inductor/compile_fx_ext.py +++ b/torch/_inductor/compile_fx_ext.py @@ -620,6 +620,7 @@ class _OutOfProcessFxCompile(_SerializedFxCompile): if output.warning_replay: for w in output.warning_replay: + # pyrefly: ignore # no-matching-overload warnings.warn_explicit( message=w.message, category=w.category, diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 6342fc7e0fcd..474cd86eb362 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -170,6 +170,7 @@ class SubprocPool: log_path = config.get_worker_log_path() if log_path: + # pyrefly: ignore # bad-assignment self.log_file = open(log_path, "w") self.process = subprocess.Popen( @@ -204,6 +205,7 @@ class SubprocPool: self, job_fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs ) -> Future[_T]: if args or kwargs: + # pyrefly: ignore # bad-assignment job_fn = functools.partial(job_fn, *args, **kwargs) job_data = self.pickler.dumps(job_fn) future: Future[_T] diff --git a/torch/_inductor/compile_worker/tracked_process_pool.py b/torch/_inductor/compile_worker/tracked_process_pool.py index 36df56b963d6..040909fafec9 100644 --- a/torch/_inductor/compile_worker/tracked_process_pool.py +++ b/torch/_inductor/compile_worker/tracked_process_pool.py @@ -83,6 +83,7 @@ class TrackedProcessPoolExecutor(ProcessPoolExecutor): def submit( self, fn: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs ) -> Future[_R]: + # pyrefly: ignore # bad-argument-type f = super().submit(fn, *args, **kwargs) self._record_enqueue(f) return f diff --git a/torch/_inductor/compiler_bisector.py b/torch/_inductor/compiler_bisector.py index bdfd02f76682..b325bcaa5378 100644 --- a/torch/_inductor/compiler_bisector.py +++ b/torch/_inductor/compiler_bisector.py @@ -243,6 +243,7 @@ class CompilerBisector: lines = cls.read_lines_from_file(file_path) low = None high = None + # pyrefly: ignore # bad-assignment for line in reversed(lines): if line.startswith("low="): low = int(line.strip().split("=")[1]) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index d671963f7852..c487d259afca 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -407,6 +407,7 @@ reorder_iterative_debug_limit_to_reorder: Optional[int] = ( else int(env_str) ) sink_waits_iterative_debug_limit_to_sink: Optional[int] = ( + # pyrefly: ignore # unbound-name None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str) ) diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 4a94ea28908c..713326db862c 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ 
b/torch/_inductor/cpu_vec_isa.py @@ -237,6 +237,7 @@ extern "C" __m512bh __avx512_bf16_chk_kernel(__m512 a, __m512 b) { """ @functools.cache # noqa: B019 + # pyrefly: ignore # bad-override def __bool__(self) -> bool: if super().__bool__(): if config.is_fbcode(): @@ -450,6 +451,7 @@ def get_isa_from_cpu_capability( "avx512": "avx512", } if capability in capability_to_isa_str.keys(): + # pyrefly: ignore # index-error isa_str = capability_to_isa_str[capability] if isa_str == "INVALID_VEC_ISA": return invalid_vec_isa diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 3b3dea909cd2..566db12e4929 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -407,6 +407,7 @@ def cudagraphify_impl( fn = align_inputs_from_check_idxs( fn, inputs_to_check=check_input_idxs, mutated_input_idxs=mutated_input_idxs ) + # pyrefly: ignore # unsupported-operation fn_cache[int_key] = fn return out @@ -922,6 +923,7 @@ class CUDAGraphNode: return None self.static_input_data_ptrs: InputList[Optional[int]] = [ + # pyrefly: ignore # bad-argument-type maybe_get_static_data_ptr(i, inputs, self.static_input_idxs) for i in range(len(inputs)) ] @@ -968,8 +970,10 @@ class CUDAGraphNode: self.expected_dead_indices_before_graph = different_indices rng_states = [inp for inp in inputs if isinstance(inp, torch.Generator)] + # pyrefly: ignore # bad-argument-type recording_inputs = self._allocate_and_copy_recording_inputs(inputs) # recording inputs will copy over memory, so we can free non recording inputs + # pyrefly: ignore # missing-attribute inputs.clear() del inputs @@ -1281,8 +1285,10 @@ class CUDAGraphNode: if not isinstance(static_outputs, (list, tuple)): static_outputs = (static_outputs,) + # pyrefly: ignore # bad-argument-type self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs) + # pyrefly: ignore # bad-return return static_outputs def _add_first_outputs( @@ -1676,6 +1682,7 @@ class CUDAGraphNode: for i, inp in enumerate(inputs): if not isinstance(inp, torch.Tensor): assert isinstance(inp, (int, torch.Generator)) + # pyrefly: ignore # bad-argument-type recording_inputs.append(inp) elif i not in self.static_input_idxs: # static_input does an allocation! 
@@ -1840,6 +1847,7 @@ def check_memory_pool( formatted = [] for dp, block in allocated_not_in_live_storages.items(): trace = format_tb(block.get("frames", [])) + # pyrefly: ignore # bad-argument-type formatted.append(f"Data Pointer: {dp}, history: \n{trace}") formatted_s = "\n".join(formatted) msg = ( @@ -2547,7 +2555,11 @@ class CUDAGraphTreeManager: live_storages_weak_refs: list[int] = [t() for t in live_storages_wrappers] # type: ignore[misc] ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation() torch._C._cuda_setCheckpointPoolState( - device, state, stale_storages, live_storages_weak_refs + device, + # pyrefly: ignore # bad-argument-type + state, + stale_storages, + live_storages_weak_refs, ) # NB: deduplicate aliased outputs diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index 60eaf7a84e6d..5dbe849af095 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -97,6 +97,7 @@ def draw_buffers( dtype = node.data.dtype metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type] + # pyrefly: ignore # missing-attribute node.meta["tensor_meta"] = metadata if print_graph: @@ -228,6 +229,7 @@ def update_orig_fx_node_name_to_buf_name( ) continue else: + # pyrefly: ignore # bad-argument-type, unsupported-operation assert len(children_nodes) == 1 and children_nodes[0] == node ir_node = node.node @@ -251,6 +253,7 @@ def get_node_name_to_buf_meta( if buf_name not in buf_name_to_n_node: buf_name_to_n_node[buf_name] = OrderedSet([node_name]) else: + # pyrefly: ignore # missing-attribute buf_name_to_n_node[buf_name].add(node_name) node_name_to_buf_meta = {} @@ -1146,9 +1149,11 @@ def set_kernel_post_grad_provenance_tracing( kernel_name, [] ) ) + # pyrefly: ignore # missing-attribute stack_traces_set.update(snode.node.get_stack_traces()) curr_node_info.extend( origin.name + # pyrefly: ignore # missing-attribute for origin in snode.node.origins if origin.name not in curr_node_info ) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index f9331170c1ae..2a17fa6d5643 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -544,6 +544,7 @@ def amax( keepdim: bool = False, ) -> torch.Tensor: if self.dtype == torch.bool: + # pyrefly: ignore # no-matching-overload return torch.any(self, dim=dim, keepdim=keepdim) return NotImplemented @@ -555,6 +556,7 @@ def amin( keepdim: bool = False, ) -> torch.Tensor: if self.dtype == torch.bool: + # pyrefly: ignore # no-matching-overload return torch.all(self, dim=dim, keepdim=keepdim) return NotImplemented @@ -1045,9 +1047,13 @@ def _max_pool_with_indices( if not stride: stride = kernel_size + # pyrefly: ignore # bad-assignment kernel_size = pad_listlike(kernel_size, dim) + # pyrefly: ignore # bad-assignment dilation = pad_listlike(dilation, dim) + # pyrefly: ignore # bad-assignment padding = pad_listlike(padding, dim) + # pyrefly: ignore # bad-assignment stride = pad_listlike(stride, dim) window_size = functools.reduce(operator.mul, kernel_size) @@ -1205,8 +1211,11 @@ def conv1d_to_conv2d( "Expect (N,C_in,L) and (C_out,C_in//groups,K)" ) + # pyrefly: ignore # bad-assignment stride = stride[0] + # pyrefly: ignore # bad-assignment padding = padding[0] + # pyrefly: ignore # bad-assignment dilation = dilation[0] # Unsqueeze to make input 2D: (N,C,L) -> (N,C,L,1) diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 21aad41c2c0c..0547b6b1db90 100644 --- a/torch/_inductor/dependencies.py +++ 
b/torch/_inductor/dependencies.py @@ -70,7 +70,9 @@ class Dep(abc.ABC): @dataclasses.dataclass(frozen=True) class MemoryDep(Dep): + # pyrefly: ignore # bad-override name: str + # pyrefly: ignore # bad-override index: sympy.Expr var_names: tuple[sympy.Symbol, ...] size: tuple[sympy.Expr, ...] @@ -306,11 +308,13 @@ class MemoryDep(Dep): @dataclasses.dataclass(frozen=True) class StarDep(Dep): + # pyrefly: ignore # bad-override name: str mode: Optional[str] = None # depends on the entire buffer @property + # pyrefly: ignore # bad-override def index(self) -> sympy.Expr: raise NotImplementedError("StarDep does not have an index") @@ -359,6 +363,7 @@ class StarDep(Dep): @dataclasses.dataclass(frozen=True) class WeakDep(Dep): # Fake dependency on unused buffer + # pyrefly: ignore # bad-override name: str # Buffer that is doing the mutation mutating_buf: str @@ -375,6 +380,7 @@ class WeakDep(Dep): return OrderedSet() @property + # pyrefly: ignore # bad-override def index(self) -> sympy.Expr: raise NotImplementedError("WeakDep does not have an index") @@ -662,8 +668,11 @@ def extract_read_writes( range_vars = [*itertools.chain.from_iterable(args)] return ReadWrites( + # pyrefly: ignore # missing-attribute OrderedSet(inner._reads), + # pyrefly: ignore # missing-attribute OrderedSet(inner._writes), + # pyrefly: ignore # missing-attribute inner._index_exprs, range_vars, var_ranges, diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index 4c30079549c5..2a15104e7162 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -58,6 +58,7 @@ def promote_types( ): dtype_prop_candidates = [] + # pyrefly: ignore # bad-assignment for arg in args: assert not isinstance(arg, str) if isinstance(arg, OpsValue): @@ -68,6 +69,7 @@ def promote_types( dtype_prop_candidates.append((type_to_dtype(type(arg)), True)) continue + # pyrefly: ignore # missing-attribute dtype_prop_candidates.append((arg.dtype, getattr(arg, "is_scalar", False))) dtype = get_promoted_dtype( diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index a46663ed8f8c..1dd25804ce23 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -132,6 +132,7 @@ class TritonMissing(ShortenTraceback): class GPUTooOldForTriton(ShortenTraceback): def __init__( self, + # pyrefly: ignore # not-a-type device_props: _CudaDeviceProperties, first_useful_frame: Optional[types.FrameType], ) -> None: diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py index 05222168095f..dd8af71c7678 100644 --- a/torch/_inductor/freezing.py +++ b/torch/_inductor/freezing.py @@ -153,6 +153,7 @@ class ErasedTensor(torch.Tensor): def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] erased_tensors = [ e + # pyrefly: ignore # bad-unpacking for e in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(e, ErasedTensor) ] @@ -177,6 +178,7 @@ def invalidate_eager_modules(): for attr_name, tensor in list( itertools.chain( mod.named_parameters(recurse=False), + # pyrefly: ignore # bad-argument-type mod.named_buffers(recurse=False), ) ): @@ -192,7 +194,9 @@ def discard_traced_gm_params(mod: torch.fx.GraphModule): with torch.utils._python_dispatch._disable_current_modes(): for attr_name, tensor in list( itertools.chain( - mod.named_parameters(recurse=False), mod.named_buffers(recurse=False) + mod.named_parameters(recurse=False), + # pyrefly: ignore # bad-argument-type + mod.named_buffers(recurse=False), ) ): with 
torch._dispatch.python.no_python_dispatcher(): diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 2e67a0d920ce..69216c8f5c5e 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -108,10 +108,12 @@ class TypeExemplars: """ Return an example of a class. """ + # pyrefly: ignore # bad-argument-type, bad-argument-count return TypeExemplars.TYPE_EXEMPLARS.get(t.__name__, None) @staticmethod def contains(t: type[T]) -> bool: + # pyrefly: ignore # bad-argument-type, bad-argument-count return t.__name__ in TypeExemplars.TYPE_EXEMPLARS diff --git a/torch/_inductor/fx_passes/b2b_gemm.py b/torch/_inductor/fx_passes/b2b_gemm.py index a87c86fe9e52..91502e963964 100644 --- a/torch/_inductor/fx_passes/b2b_gemm.py +++ b/torch/_inductor/fx_passes/b2b_gemm.py @@ -580,6 +580,7 @@ def tuned_b2b_gemm( # match the inner mm of a potential b2b_gemm @register_graph_pattern( CallFunction(torch.ops.aten.mm, Arg(), Arg()), + # pyrefly: ignore # bad-argument-type pass_dict=B2B_GEMM_PASS, ) def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) -> None: @@ -699,22 +700,33 @@ def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) -> new_input_anchor = new_node if node is f_node: new_output_anchor = new_node + # pyrefly: ignore # unbound-name if new_input_anchor is not new_output_anchor: # subgraph is non-trivial # update the input node + # pyrefly: ignore # unbound-name with new_graph.inserting_before(new_input_anchor): new_input_node = new_graph.placeholder(name="subgraph_input") + # pyrefly: ignore # unbound-name new_input_node.meta.update(new_input_anchor.meta) + # pyrefly: ignore # unbound-name new_input_anchor.replace_all_uses_with(new_input_node) + # pyrefly: ignore # unbound-name new_graph.erase_node(new_input_anchor) # add the output node + # pyrefly: ignore # unbound-name new_output_node = new_graph.output(new_output_anchor) + # pyrefly: ignore # unbound-name new_output_node.meta.update(new_output_anchor.meta) else: # subgraph is trivial, e.g. (A @ (B @ C)) # update the input node + # pyrefly: ignore # unbound-name with new_graph.inserting_before(new_input_anchor): new_input_node = new_graph.placeholder(name="subgraph_input") + # pyrefly: ignore # unbound-name new_input_node.meta.update(new_input_anchor.meta) + # pyrefly: ignore # unbound-name new_input_anchor.replace_all_uses_with(new_input_node) + # pyrefly: ignore # unbound-name new_graph.erase_node(new_input_anchor) # update the output node (don't use new_output_anchor since it has been erased) new_output_node = new_graph.output(new_input_node) diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index ccea7d7e70af..8f5cc7bc5d2b 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -215,6 +215,7 @@ def _fuse_allreduce_by_concat( # Move the fused all_reduce and its args to right after the input node nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node] + # pyrefly: ignore # bad-argument-type move_block_after(nodes_to_move, last_input_node) return CommBlock( @@ -307,6 +308,7 @@ def _scatter_fused_allreduce_waits( # in orig_comm_blocks. This index will be later used to determine what users # nodes need to be move to maintain a correct topological sort order. 
last_wait_node_idx = 0 + # pyrefly: ignore # bad-assignment for node in graph.nodes: last_wait_node_idx = max( node_indices.get(node, last_wait_node_idx), last_wait_node_idx @@ -356,6 +358,7 @@ def _scatter_fused_allreduce_waits( user_node = nodes.popleft() if not isinstance(user_node, fx.Node): continue + # pyrefly: ignore # unsupported-operation if node_indices[user_node] < last_wait_node_idx: incorrect_order_nodes.append(user_node) nodes.extend(list(user_node.users)) diff --git a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py index 31c6dae82fdb..127384c87f10 100644 --- a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +++ b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -225,6 +225,7 @@ def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): if should_decompose_bmm(mat1, mat2): counters["inductor"]["decompose_bmm"] += 1 + # pyrefly: ignore # bad-argument-type match.replace_by_example(repl, [mat1, mat2]) print_decompose_pattern(match, [mat1, mat2]) realize_inputs([mat1, mat2]) @@ -248,6 +249,7 @@ def decompose_addmm( if should_decompose_mm(mat2, mat3): counters["inductor"]["decompose_addmm"] += 1 + # pyrefly: ignore # bad-argument-type match.replace_by_example(repl, [mat1, mat2, mat3]) print_decompose_pattern(match, [mat1, mat2, mat3]) realize_inputs([mat1, mat2, mat3]) @@ -268,6 +270,7 @@ def decompose_mm( if should_decompose_mm(mat1, mat2): counters["inductor"]["decompose_mm"] += 1 + # pyrefly: ignore # bad-argument-type match.replace_by_example(repl, [mat1, mat2]) print_decompose_pattern(match, [mat1, mat2]) realize_inputs([mat1, mat2]) diff --git a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py index 0e647e37cd34..78cd317284d2 100644 --- a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +++ b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -144,6 +144,7 @@ def efficient_conv_bn_eval_decomposed( torch.nn.functional.batch_norm, ] ), + # pyrefly: ignore # bad-argument-type pass_dict=efficient_conv_bn_eval_pass, extra_check=lambda match: not inductor_config.freezing and inductor_config.efficient_conv_bn_eval_fx_passes, @@ -235,6 +236,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs torch.ops.aten.batch_norm.default, ] ), + # pyrefly: ignore # bad-argument-type pass_dict=efficient_conv_bn_eval_pass, extra_check=lambda match: not inductor_config.freezing and inductor_config.efficient_conv_bn_eval_fx_passes, @@ -330,6 +332,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa nn.SyncBatchNorm, ], ), + # pyrefly: ignore # bad-argument-type pass_dict=efficient_conv_bn_eval_pass, extra_check=lambda match: not inductor_config.freezing and inductor_config.efficient_conv_bn_eval_fx_passes, diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py index 26256b5504d7..c3eed5660479 100644 --- a/torch/_inductor/fx_passes/freezing_patterns.py +++ b/torch/_inductor/fx_passes/freezing_patterns.py @@ -107,6 +107,7 @@ def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_numb return register_graph_pattern( pattern, extra_check=extra_check, + # pyrefly: ignore # bad-argument-type pass_dict=pass_patterns[pass_number], ) @@ -115,6 +116,7 @@ def register_binary_folding_pattern(pattern, extra_check=_return_true): return register_graph_pattern( pattern, extra_check=extra_check, + # pyrefly: ignore # 
bad-argument-type pass_dict=binary_folding_pass, ) @@ -202,10 +204,13 @@ def addmm_patterns_init(): return mm.tensor_split([n1, n1 + n2], dim=-1) register_replacement( + # pyrefly: ignore # bad-argument-type int8_woq_fusion_pattern, + # pyrefly: ignore # bad-argument-type int8_woq_fusion_replacement, [val(), val(), val(), val(), scale(), scale(), scale()], fwd_only, + # pyrefly: ignore # bad-argument-type pass_patterns[0], extra_check=check_int8_woq_concat_linear_weights, exclusive_arg_names=("w1", "w2", "w3", "s1", "s2", "s3"), @@ -220,10 +225,13 @@ def addmm_patterns_init(): return mm.chunk(3, dim=1) register_replacement( + # pyrefly: ignore # bad-argument-type matmul_fuse_pattern, + # pyrefly: ignore # bad-argument-type matmul_replacement, [val(), val(), val(), val()], fwd_only, + # pyrefly: ignore # bad-argument-type pass_patterns[0], extra_check=check_concat_weights, exclusive_arg_names=("w1", "w2", "w3"), @@ -238,10 +246,13 @@ def addmm_patterns_init(): return mm.chunk(2, dim=1) register_replacement( + # pyrefly: ignore # bad-argument-type matmul_fuse_pattern_two, + # pyrefly: ignore # bad-argument-type matmul_replacement_two, [val(), val(), val()], fwd_only, + # pyrefly: ignore # bad-argument-type pass_patterns[0], extra_check=check_concat_weights, exclusive_arg_names=("w1", "w2"), @@ -260,10 +271,13 @@ def addmm_patterns_init(): return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1) register_replacement( + # pyrefly: ignore # bad-argument-type addmm_fuse_pattern_second, + # pyrefly: ignore # bad-argument-type addmm_fuse_replacement_second, [val() for _ in range(7)], fwd_only, + # pyrefly: ignore # bad-argument-type pass_patterns[0], extra_check=check_concat_weights, exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"), @@ -280,6 +294,7 @@ def same_dtype(match): Ignored(), KeywordArg("dtype"), ), + # pyrefly: ignore # bad-argument-type pass_dict=pass_patterns[0], extra_check=same_dtype, ) diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index f081374585ee..743d9a1b85a0 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -840,6 +840,7 @@ class BatchLayernormFusion(BatchFusion): ) update_pointwise_example_value( batch_layer_norm, + # pyrefly: ignore # missing-attribute stack_weight.meta["example_value"], previous_batch_layer_norm_meta, torch.mul, @@ -850,28 +851,33 @@ class BatchLayernormFusion(BatchFusion): ) update_pointwise_example_value( batch_layer_norm, + # pyrefly: ignore # missing-attribute stack_bias.meta["example_value"], previous_batch_layer_norm_meta, torch.add, ) elif group_weights is not None and group_biases is None: previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + # pyrefly: ignore # not-callable batch_layer_norm = graph.call_function( torch.mul, args=(stack_weight, batch_layer_norm) ) update_pointwise_example_value( batch_layer_norm, + # pyrefly: ignore # missing-attribute stack_weight.meta["example_value"], previous_batch_layer_norm_meta, torch.mul, ) elif group_weights is None and group_biases is not None: previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + # pyrefly: ignore # not-callable batch_layer_norm = graph.call_function( torch.add, args=(stack_bias, batch_layer_norm) ) update_pointwise_example_value( batch_layer_norm, + # pyrefly: ignore # missing-attribute stack_bias.meta["example_value"], previous_batch_layer_norm_meta, torch.add, diff --git 
a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index a95bd1b203fd..aa06049a9c65 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -517,6 +517,7 @@ def canonicalize_quant_mapping(gm: torch.fx.GraphModule): invoke_quant_replacement = graph.call_function( torch._higher_order_ops.invoke_quant, (subgraph, *args), + # pyrefly: ignore # bad-argument-type kwargs, ) invoke_quant_replacement.meta.update(subgraph.meta) @@ -633,6 +634,7 @@ def joint_graph_passes(graph: torch.fx.GraphModule): device=KeywordArg("device"), requires_grad=KeywordArg("requires_grad"), ), + # pyrefly: ignore # bad-argument-type pass_dict=patterns, ) def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad): @@ -686,6 +688,7 @@ def fix_iota_device(match: Match, length, start, step, dtype, device, requires_g ), KeywordArg("dtype2"), ), + # pyrefly: ignore # bad-argument-type pass_dict=patterns, ) def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype): @@ -748,6 +751,7 @@ def definitely_equal( @register_graph_pattern( CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")), + # pyrefly: ignore # bad-argument-type pass_dict=patterns, ) def pointless_view(match: Match, arg, size): @@ -765,6 +769,7 @@ def pointless_view(match: Match, arg, size): CallFunction(aten.view.default, KeywordArg("arg"), KeywordArg("size1")), KeywordArg("size2"), ), + # pyrefly: ignore # bad-argument-type pass_dict=patterns, ) def pointless_view_pair(match: Match, arg, size1, size2): @@ -785,6 +790,7 @@ def pointless_view_pair(match: Match, arg, size1, size2): CallFunction(aten.permute.default, KeywordArg("arg"), KeywordArg("perm1")), KeywordArg("perm2"), ), + # pyrefly: ignore # bad-argument-type pass_dict=patterns, ) def pointless_permute_pair(match: Match, arg, perm1, perm2): @@ -805,6 +811,7 @@ def pointless_permute_pair(match: Match, arg, perm1, perm2): Arg(), Arg(), ), + # pyrefly: ignore # bad-argument-type pass_dict=patterns, ) def bmm_to_mm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): @@ -818,6 +825,7 @@ def bmm_to_mm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): and statically_known_true(mat1.meta["val"].shape[0] == 1) and statically_known_true(mat2.meta["val"].shape[0] == 1) ): + # pyrefly: ignore # bad-argument-type match.replace_by_example(repl, [mat1, mat2]) @@ -907,14 +915,17 @@ def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): inp = inp * sign max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + # pyrefly: ignore # unsupported-operation return (inp - max_) * (sign * other) + # pyrefly: ignore # bad-argument-type match.replace_by_example(repl, [inp, other]) for reverse, to_dtype in itertools.product((False, True), repeat=2): register_graph_pattern( _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype), + # pyrefly: ignore # bad-argument-type pass_dict=pass_patterns[1], extra_check=_other_is_broadcasted_in_dim, )(mul_softmax_pattern) @@ -934,14 +945,17 @@ def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): inp = inp * sign max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + # pyrefly: ignore # unsupported-operation return (inp - max_) / (sign * other) + # pyrefly: ignore # bad-argument-type match.replace_by_example(repl, [inp, other]) for to_dtype in (False, True): register_graph_pattern( _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype), + # pyrefly: ignore # 
bad-argument-type pass_dict=pass_patterns[1], extra_check=_other_is_broadcasted_in_dim, )(div_softmax_pattern) diff --git a/torch/_inductor/fx_passes/misc_patterns.py b/torch/_inductor/fx_passes/misc_patterns.py index d2c8068f130c..538a2ca2c43b 100644 --- a/torch/_inductor/fx_passes/misc_patterns.py +++ b/torch/_inductor/fx_passes/misc_patterns.py @@ -44,10 +44,13 @@ def _misc_patterns_init(): ) register_replacement( + # pyrefly: ignore # bad-argument-type randperm_index_add_pattern, + # pyrefly: ignore # bad-argument-type randperm_index_add_replacement, [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)], fwd_only, + # pyrefly: ignore # bad-argument-type [post_grad_patterns, joint_graph_patterns], ) @@ -60,10 +63,13 @@ def _misc_patterns_init(): return torch.ops.aten._unsafe_index(x, (index,)), index register_replacement( + # pyrefly: ignore # bad-argument-type randperm_index_pattern, + # pyrefly: ignore # bad-argument-type randperm_index_replacement, [torch.empty(4, 8, device=device)], fwd_only, + # pyrefly: ignore # bad-argument-type [post_grad_patterns, joint_graph_patterns], scalar_workaround={"slice_shape": 42}, ) diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 868eb74824dd..99e8cfeff793 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -712,6 +712,7 @@ if torch._C._has_mkldnn: if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes): return False if any( + # pyrefly: ignore # missing-attribute n.args[other_index].op in ["placeholder", "output"] for n in binary_nodes ): diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 69cc0bb476b9..20d4abda9652 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -68,6 +68,7 @@ def get_hint(x: Union[int, torch.SymInt]) -> Optional[int]: def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]: with dynamo_timed("collective_compute_do_bench"): return functools.partial( + # pyrefly: ignore # bad-argument-type torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu, warmup=5, ) diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 42ee33a367f0..f58678e7651e 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -97,6 +97,7 @@ def should_pad_common( if isinstance(x, int): continue elif utils.is_symbolic(x): + # pyrefly: ignore # missing-attribute if not x.node.has_hint(): return False symbolic_cnt += 1 @@ -106,6 +107,7 @@ def should_pad_common( if symbolic_cnt == len(t.size()): return False return all( + # pyrefly: ignore # missing-attribute isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint()) for x in t.stride() ) @@ -399,6 +401,7 @@ def should_pad_bench(*args: Any, **kwargs: Any) -> bool: def get_do_bench() -> Callable[[Callable[[], Any]], float]: with dynamo_timed("pad_mm_benchmark_get_do_bench"): return functools.partial( + # pyrefly: ignore # bad-argument-type torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu, warmup=5, ) @@ -483,6 +486,7 @@ def _should_pad_bench( def realize_tensor(t): if isinstance(t, FakeTensor): size_hints = realize_symbols(t.size()) + # pyrefly: ignore # bad-argument-type stride_hint = realize_symbols(t.stride()) real_size = ( sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1 @@ -918,6 +922,7 @@ def _pad_mm_init() 
-> None: replacement, args, joint_fwd_bwd, + # pyrefly: ignore # bad-argument-type patterns, extra_check=extra_check, scalar_workaround=workaround, @@ -929,6 +934,7 @@ def _pad_mm_init() -> None: replacement, args, fwd_only, + # pyrefly: ignore # bad-argument-type patterns, extra_check=extra_check, scalar_workaround=workaround, diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index a07f3cfa1b14..db9f6f8563e6 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -338,6 +338,7 @@ def decompose_map_to_while_loop(gm: torch.fx.GraphModule): @register_graph_pattern( CallFunctionVarArgs(torch.ops.higher_order.map_impl), + # pyrefly: ignore # bad-argument-type pass_dict=graph_pass, ) def _(match: Match, *args, **kwargs): @@ -525,6 +526,7 @@ def decompose_scan_to_while_loop(gm: torch.fx.GraphModule): @register_graph_pattern( CallFunctionVarArgs(torch.ops.higher_order.scan), + # pyrefly: ignore # bad-argument-type pass_dict=graph_pass, ) def _(match: Match, *args, **kwargs): @@ -658,11 +660,14 @@ def lazy_init(): # pass since otherwise there will be perf/peak-memory regression: # https://github.com/pytorch/pytorch/issues/148141 register_replacement( + # pyrefly: ignore # bad-argument-type prepare_softmax_pattern, + # pyrefly: ignore # bad-argument-type prepare_softmax_replacement, [torch.empty(4, 8)], scalar_workaround=dict(dim=-1), trace_fn=fwd_only, + # pyrefly: ignore # bad-argument-type pass_dicts=pass_patterns[1], extra_check=prepare_softmax_extra_check, ) @@ -723,7 +728,9 @@ def register_lowering_pattern( Register an aten to inductor IR replacement pattern """ return pattern_matcher.register_lowering_pattern( - pattern, extra_check, pass_dict=pass_patterns[pass_number] + pattern, + extra_check, + pass_dict=pass_patterns[pass_number], ) @@ -820,6 +827,7 @@ def scatter_upon_const_tensor( device = selector.device if hasattr(selector, "device") else torch.device("cpu") return torch.empty(shape, dtype=dtype, device=device) + # pyrefly: ignore # bad-assignment metrics.num_matches_for_scatter_upon_const_tensor += 1 selector_loader = selector.make_loader() @@ -871,6 +879,7 @@ def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): KeywordArg("dim"), _users=MULTIPLE, ), + # pyrefly: ignore # bad-argument-type pass_dict=pass_patterns[1], ) def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim): @@ -890,6 +899,7 @@ def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, # only replace the output node, not all nodes match.nodes = [match.output_node()] + # pyrefly: ignore # bad-argument-type match.replace_by_example(repl, list(shape)) @@ -1207,6 +1217,7 @@ def decompose_triton_kernel_wrapper_functional(graph): @register_graph_pattern( CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional), + # pyrefly: ignore # bad-argument-type pass_dict=graph_pass, ) def _(match: Match, *args, **kwargs): @@ -1223,6 +1234,7 @@ def decompose_triton_kernel_wrapper_functional(graph): args, kwargs = pytree.tree_unflatten(flat_args, spec) return (triton_kernel_wrapper_functional_dense(*args, **kwargs),) + # pyrefly: ignore # bad-argument-type match.replace_by_example(decomp, flat_args, run_functional_passes=False) graph_pass.apply(graph) @@ -1246,6 +1258,7 @@ def decompose_auto_functionalized(graph): @register_graph_pattern( CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized), + # pyrefly: ignore # bad-argument-type pass_dict=graph_pass, ) def _(match: 
Match, *args, **kwargs):
@@ -1266,10 +1279,12 @@ def decompose_auto_functionalized(graph):
             mode = args[0]
             return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs)
+        # pyrefly: ignore # bad-argument-type
         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
     @register_graph_pattern(
         CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2),
+        # pyrefly: ignore # bad-argument-type
         pass_dict=graph_pass,
     )
     def _(match: Match, *args, **kwargs):
@@ -1310,6 +1325,7 @@ def decompose_auto_functionalized(graph):
             mutable_op, only_clone_these_bases, **kwargs
         )
+        # pyrefly: ignore # bad-argument-type
         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
     graph_pass.apply(graph)
@@ -1474,6 +1490,7 @@ def should_prefer_unfused_addmm(match):
 @register_graph_pattern(
     CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
+    # pyrefly: ignore # bad-argument-type
     pass_dict=pass_patterns[2],
     extra_check=should_prefer_unfused_addmm,
 )
@@ -1481,6 +1498,7 @@ def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
     def repl(inp, x1, x2):
         return x1 @ x2 + inp
+    # pyrefly: ignore # bad-argument-type
     match.replace_by_example(repl, [inp, mat1, mat2])
@@ -1514,6 +1532,7 @@ def is_valid_addmm_fusion(match):
         CallFunction(aten.mm, Arg(), Arg()),
         KeywordArg("inp"),
     ),
+    # pyrefly: ignore # bad-argument-type
     pass_dict=pass_patterns[2],
     extra_check=is_valid_addmm_fusion,
 )
@@ -1523,6 +1542,7 @@ def is_valid_addmm_fusion(match):
         KeywordArg("inp"),
         CallFunction(aten.mm, Arg(), Arg()),
     ),
+    # pyrefly: ignore # bad-argument-type
     pass_dict=pass_patterns[2],
     extra_check=is_valid_addmm_fusion,
 )
@@ -1552,7 +1572,8 @@ def register_partial_reduction_pattern():
     full_reduc = CallFunction([red_op, equiv_red[red_op]], inp)
     @register_graph_pattern(
-        MultiOutputPattern([partial_reduc, full_reduc]), pass_dict=pass_patterns[2]
+        MultiOutputPattern([partial_reduc, full_reduc]),
+        pass_dict=pass_patterns[2],
     )
     def reuse_partial(match, input, reduced_dims, keepdim):
         partial_red, full_red = match.output_nodes()
@@ -1728,6 +1749,7 @@ class ConstructorMoverPass:
             pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))
+            # pyrefly: ignore # redundant-condition
             if cpu_count:
                 cpu_indeg[node] = cpu_count
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index 3b851b0e27ae..597013f6233c 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -509,6 +509,7 @@ def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModul
             conv = conv_bn_fusion.conv_module
             bn = conv_bn_fusion.bn_module
+            # pyrefly: ignore # bad-argument-type
             fused_conv = fuse_conv_bn_eval(conv, bn)
             for bn_node in bn_nodes:
                 replace_node_module(bn_node.args[0], modules, fused_conv)
@@ -596,8 +597,11 @@ def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModul
                 fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
                     fused_conv.weight,
                     fused_conv.bias,
+                    # pyrefly: ignore # bad-argument-type
                     bn_running_mean,
+                    # pyrefly: ignore # bad-argument-type
                     bn_running_var,
+                    # pyrefly: ignore # bad-argument-type
                     bn_eps,
                     bn_weight,
                     bn_bias,
diff --git a/torch/_inductor/fx_passes/replace_random.py b/torch/_inductor/fx_passes/replace_random.py
index de533679bbe6..4c7f8887f7ae 100644
--- a/torch/_inductor/fx_passes/replace_random.py
+++ b/torch/_inductor/fx_passes/replace_random.py
@@ -88,9 +88,13 @@ def get_device(device):
     return torch.empty([]).device  # default device
+# pyrefly: ignore # bad-argument-type
 @register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns)
+# pyrefly: ignore # bad-argument-type
 @register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns)
+# pyrefly: ignore # bad-argument-type
 @register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns)
+# pyrefly: ignore # bad-argument-type
 @register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns)
 def replace_random(
     match: Match,
@@ -120,9 +124,11 @@ def replace_random(
         match.output_node().target.overloadpacket  # type: ignore[union-attr]
     ]  # type: ignore[union-attr]
     device = get_device(device)
+    # pyrefly: ignore # bad-argument-type
     match.replace_by_example(replacement, [size])
+# pyrefly: ignore # bad-argument-type
 @register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns)
 def replace_randint(
     match: Match,
@@ -140,4 +146,5 @@ def replace_randint(
         return result.to(dtype)
     device = get_device(device)
+    # pyrefly: ignore # bad-argument-type
     match.replace_by_example(replacement, [low, high, size])
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index 899960ac435c..9b0f5956cce6 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -291,6 +291,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
         log.debug("example value absent for node: %s", input)
         return
     ndim = input.meta["example_value"].ndim
+    # pyrefly: ignore # unsupported-operation
     if dim < 0:  # Normalize unbind dim
         dim += ndim
     with graph.inserting_after(node):
@@ -340,6 +341,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
         ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
     )
+    # pyrefly: ignore # unsupported-operation
     if cat_dim < 0:  # Normalize cat dim
         cat_dim += ndim
@@ -947,6 +949,7 @@ class SplitCatSimplifier:
                 if isinstance(user_input, tuple):
                     # Find the correct new getitem (present in split_items)
                     new_user_inputs.append(
+                        # pyrefly: ignore # bad-argument-type
                         split_items[
                             split_ranges.index(
                                 (
@@ -997,6 +1000,7 @@ class SplitCatSimplifier:
             for user_input_new, transform_param in zip(
                 user_inputs_new, transform_params
             ):
+                # pyrefly: ignore # bad-argument-type
                 if not is_node_meta_valid(user_input_new):
                     log.debug("example value absent for node: %s", user_input_new)
                     return
@@ -1011,6 +1015,7 @@ class SplitCatSimplifier:
                     stack_dim is None or stack_dim == unsqueeze_params[0]
                 ):
                     to_stack.append(user_input_new)
+                    # pyrefly: ignore # missing-attribute
                     to_stack_meta.append(user_input_new.meta["example_value"])
                     stack_dim = unsqueeze_params[0]
                     continue
@@ -1031,10 +1036,12 @@ class SplitCatSimplifier:
                 if unsqueeze_params:
                     to_stack.append(user_input_new)
                     stack_dim = unsqueeze_params[0]
+                    # pyrefly: ignore # missing-attribute
                     to_stack_meta.append(user_input_new.meta["example_value"])
                     continue
                 if unflatten_params:
+                    # pyrefly: ignore # missing-attribute
                     user_input_new_meta = user_input_new.meta["example_value"]
                     user_input_new = graph.call_function(
                         torch.unflatten, args=(user_input_new, *unflatten_params)
@@ -1044,6 +1051,7 @@ class SplitCatSimplifier:
                         *unflatten_params,  # type: ignore[arg-type]
                     )
                 if movedim_params:
+                    # pyrefly: ignore # missing-attribute
                     user_input_new_meta = user_input_new.meta["example_value"]
                     user_input_new = graph.call_function(
                         torch.movedim, args=(user_input_new, *movedim_params)
@@ -1053,6 +1061,7 @@ class SplitCatSimplifier:
                         *movedim_params,  # type: ignore[arg-type]
                     )
                 if flatten_params:
+                    # pyrefly: ignore # missing-attribute
                     user_input_new_meta = user_input_new.meta["example_value"]
                     user_input_new = graph.call_function(
                         torch.flatten, args=(user_input_new, *flatten_params)
@@ -1063,6 +1072,7 @@ class SplitCatSimplifier:
                     )
                 user_inputs_new_transformed.append(user_input_new)
                 user_inputs_new_transformed_meta.append(
+                    # pyrefly: ignore # missing-attribute
                     user_input_new.meta["example_value"]
                 )
             if to_stack:
@@ -1422,6 +1432,7 @@ def simplify_split_cat(match: Match, split_sections: list[int], dim: int):
     if not isinstance(split_sections, (list, tuple)):  # Unnormalized split
         return
     split_node = next(node for node in match.nodes if node.target == torch.split)
+    # pyrefly: ignore # bad-argument-type
     SplitCatSimplifier().simplify(match.graph, split_node, split_sections)
@@ -1490,6 +1501,7 @@ def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) -
     for i in range(len(split_node.args[1])):  # type: ignore[arg-type]
         if i in indices:
             fused_tensor_size += split_node.args[1][i]  # type: ignore[operator, assignment, index]
+    # pyrefly: ignore # bad-return
     return fused_tensor_size
@@ -1966,6 +1978,7 @@ def normalize_cat_default_aten(match: Match, *args, **kwargs):
     assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors)
+    # pyrefly: ignore # unsupported-operation
     if cat_dim < 0:  # Normalize cat dim
         cat_dim += ndim
@@ -3031,5 +3044,6 @@ def replace_einsum_to_pointwise(match: Match, *args, **kwargs):
     einsum_node = match.nodes[0]
     input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2)
     if should_replace_einsum(einsum_node):
+        # pyrefly: ignore # bad-argument-type
         match.replace_by_example(repl, [input, weights])
         counters[backend]["einsum_to_pointwise_pass"] += 1
diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py
index c754c0324868..fdc60e19efb6 100644
--- a/torch/_inductor/fx_utils.py
+++ b/torch/_inductor/fx_utils.py
@@ -238,6 +238,7 @@ class FakeTensorUpdater:
                 symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
             ):
                 # Refresh the bindings to the new symbols
+                # pyrefly: ignore # unbound-name
                 node.meta["unbacked_bindings"] = symbol_to_path
             existing_storages[get_node_storage(node)] += 1
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 803e0c298265..45d2c3134e48 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -586,6 +586,7 @@ class GraphLowering(torch.fx.Interpreter):
                 isinstance(node, ir.ComputedBuffer)
                 and node.name in self.buffer_to_padded_size
             ):
+                # pyrefly: ignore # index-error
                 return self.buffer_to_padded_size[node.name]
             else:
                 return node.get_size()
@@ -1116,6 +1117,7 @@ class GraphLowering(torch.fx.Interpreter):
                 self.constants[name].to(device_override),
             )
+    # pyrefly: ignore # bad-override
     def placeholder(
         self,
         target: str,  # type: ignore[override]
@@ -1340,6 +1342,7 @@ class GraphLowering(torch.fx.Interpreter):
         """
         return len(t.shape) == 1 and t.shape[0] <= 8
+    # pyrefly: ignore # bad-override
     def get_attr(
         self,
         target: str,  # type: ignore[override]
@@ -1397,6 +1400,7 @@ class GraphLowering(torch.fx.Interpreter):
     def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
         raise AssertionError
+    # pyrefly: ignore # bad-override
     def output(
         self,
         target: str,  # type: ignore[override]
@@ -1641,7 +1645,12 @@ class GraphLowering(torch.fx.Interpreter):
                 inp_args = eager_input_vals[0]
                 inp_kwargs = eager_input_vals[1]
                 args, kwargs = constrain_to_fake_tensors(
-                    args, kwargs, inp_args, inp_kwargs
+                    # pyrefly: ignore # unbound-name
+                    args,
+                    # pyrefly: ignore # unbound-name
+                    kwargs,
+                    inp_args,
+                    inp_kwargs,
                 )
             else:
                 args, kwargs = constrain_to_fx_strides(n, *args, **kwargs)  # type: ignore[index]
@@ -1737,7 +1746,9 @@ class GraphLowering(torch.fx.Interpreter):
                 # require_exact_strides to handle views. But ultimately it's better to require
                 # the right strides at the tensor definition.
                 if n.meta["val"]._is_view() or isinstance(
-                    result.data, ir.BaseView
+                    # pyrefly: ignore # missing-attribute
+                    result.data,
+                    ir.BaseView,
                 ):
                     result = ir.ExternKernel.require_stride_order(
                         result,
@@ -1815,6 +1826,7 @@ class GraphLowering(torch.fx.Interpreter):
                     ),
                 )
             if user.op == "output":
+                # pyrefly: ignore # missing-attribute
                 if isinstance(result.data.data, (Pointwise, Reduction)):
                     result.realize()
@@ -2163,6 +2175,7 @@ class GraphLowering(torch.fx.Interpreter):
                     continue
                 dynamic_grid = True
                 new_grid.append(grid_outputs[visited_grids[val]])
+            # pyrefly: ignore # bad-argument-type
             new_grids.append(tuple(new_grid))
         if dynamic_grid:
@@ -2184,6 +2197,7 @@ class GraphLowering(torch.fx.Interpreter):
             x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
         ) -> Union[int, float, torch.Tensor]:
             if x is None:
+                # pyrefly: ignore # bad-return
                 return None
             elif isinstance(x, (torch.SymInt, torch.SymFloat)):
                 # Need concrete value to run dynamic shapes and tune the result
diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py
index 0dc0a00412a8..a74540acc2ef 100644
--- a/torch/_inductor/index_propagation.py
+++ b/torch/_inductor/index_propagation.py
@@ -333,6 +333,7 @@ class IndexPropagation(DefaultHandler):
                 for k, v in self.indirect_var_ranges.items()
             ),
         )
+        # pyrefly: ignore # bad-argument-type
         return statically_known_true(self.shape_env, e, self.axioms, var_to_range)
     def indirect_indexing(
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index f7808f2de74b..786f6a0a204a 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -379,6 +379,7 @@ def ir_node_to_tensor(
     dtype = x.get_dtype()
     device = x.get_device()
     size = convert_shape_to_symint(size)
+    # pyrefly: ignore # bad-assignment
     stride = convert_shape_to_symint(stride)
     with V.graph.sizevars.shape_env.suppress_guards():
         t = torch.empty_strided(
@@ -406,6 +407,7 @@ def get_device_type(
         return x.type
     elif isinstance(x, (IRNode, OutputSpec)):
         return get_device_type(x.get_device())
+    # pyrefly: ignore # bad-argument-type
     assert_never(f"get_device_type({x}: {type(x).__name__})")
@@ -614,7 +616,9 @@ class IRNode:
             else:
                 pre_grad_nodes = (
                     torch._inductor.debug._inductor_post_to_pre_grad_nodes.get(
-                        "postToPre", {}
+                        "postToPre",
+                        {},
+                        # pyrefly: ignore # missing-attribute
                     ).get(node.name, [])
                 )
                 if not isinstance(pre_grad_nodes, list):
@@ -650,6 +654,7 @@ class IRNode:
         lines = list(lines) + list(self.common_repr(shorten))
         lines = list(map(str, lines))
         if multiline:
+            # pyrefly: ignore # no-matching-overload
            new_lines = indent(",\n".join(lines))
            return f"{type(self).__name__}(\n{new_lines}\n)"
        else:
@@ -1481,6 +1486,7 @@ class Reduction(Loops):
         return fn
     @classmethod
+    # pyrefly: ignore # bad-override
     def create(
         cls,
         device: torch.device,
@@ -2422,6 +2428,7 @@ class Scan(Loops):
         scan_type = Scan
         if num_splits > 1:
             supports_split = (
+                # pyrefly: ignore # unsupported-operation
                 torch.version.hip is None or (has_triton and triton_version >= "3.3.0")
             ) and (len(dtypes) == 1)
             if not supports_split:
@@ -2882,6 +2889,7 @@ class ExpandView(BaseView):
                 # NB: new_size[i] == old_size[i] is expected to already be
                 # guarded because the meta formula was expected to have taught
                 # us this equality.
+                # pyrefly: ignore # unsupported-operation
                 assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
                     f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
                 )
@@ -3608,6 +3616,7 @@ class Layout(OutputSpec):
     ) -> None:
         if stride is None:
             stride = FlexibleLayout.contiguous_strides(size)
+        # pyrefly: ignore # read-only
         self.device = device
         self.dtype = dtype
         assert len(size) == len(stride), f"size={size}, stride={stride}"
@@ -3791,6 +3800,7 @@ class Layout(OutputSpec):
             # [25, 25, 5, 1].
             return in_strides
+        # pyrefly: ignore # bad-assignment
         metrics.num_comprehensive_padding += 1
         return new_strides
@@ -4840,6 +4850,7 @@ class TemplateBuffer(OperationBuffer):
             def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
                 assert len(rindex) == 0
+                # pyrefly: ignore # missing-attribute
                 return ops.load(inp.get_name(), indexer(index))
             deps.reads |= dependencies.extract_read_writes(
@@ -5163,6 +5174,7 @@ class CppTemplateBuffer(TemplateBuffer):
     def get_layout(self) -> Layout:
         if isinstance(self.layout, MultiOutputLayout):
             assert isinstance(self.outputs, Iterable), type(self.outputs)
+            # pyrefly: ignore # index-error
             first_output = self.outputs[0]
             assert isinstance(first_output, Buffer), type(first_output)
             layout = first_output.layout
@@ -5479,6 +5491,7 @@ class ConcatKernel(NopKernel):
             # ExternKernelAlloc has specific requirements for output layout, should create a copy
             assert hasattr(src.data, "layout")
             if cls.can_realize_into_without_copy(src, dst):
+                # pyrefly: ignore # missing-attribute
                 src.data.layout = NonOwningLayout(dst)
                 return src.data
             # introduce a copy
@@ -7160,6 +7173,7 @@ class IndexPutFallback(ExternKernel):
     ) -> None:
         self.indices = indices
         valid_indices = [i for i in indices if i is not None]
+        # pyrefly: ignore # bad-argument-type
         tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
         cpp_kernel_name = "aoti_torch_index_put_out"
         super().__init__(
@@ -7532,6 +7546,7 @@ class FallbackKernel(ExternKernelAlloc):
                     add_alias(optional_tensor_arg)
             else:
                 assert library_utils.is_tensor_like_type(info.type)
+                # pyrefly: ignore # bad-argument-type
                 add_alias(arg)
         for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
@@ -7963,6 +7978,7 @@ class FallbackKernel(ExternKernelAlloc):
             packed.outputs = tuple(outputs)
         else:
             packed.outputs = [outputs]
+        # pyrefly: ignore # bad-return
         return outputs
     def apply_constraint(self) -> None:
@@ -8409,6 +8425,7 @@ class InvokeSubgraph(ExternKernel):
         # Realize the inputs. Also intermediates can have different strides than
         # the inputs of the subgraph. So, force the intermediates to have same
         # strides as that of subgraph inputs.
+        # pyrefly: ignore # annotation-mismatch
         operands: list[IRNode] = [cls.realize_input(x) for x in operands]
         new_operands: list[IRNode] = []
@@ -8420,6 +8437,7 @@ class InvokeSubgraph(ExternKernel):
                     constrain_to_fake_tensor(operand, fake_operands[idx])
                 )
+        # pyrefly: ignore # bad-assignment
         operands = new_operands
         if subgraph.graph is None:
@@ -8530,7 +8548,9 @@ class Conditional(ExternKernel):
         operands: list[Union[TensorBox, ShapeAsConstantBuffer]],
     ) -> Sequence[IRNode]:
         """Create a Sequence of IRNodes from a conditional statement (see .lowering.cond)"""
+        # pyrefly: ignore # bad-assignment
         predicate = cls.realize_input(predicate)
+        # pyrefly: ignore # bad-assignment
         operands = [cls.realize_input(x) for x in operands]
         fx_operands: Argument = V.graph.current_node.args[-1]
@@ -9325,6 +9345,7 @@ class _WaitKernel(_CollectiveKernel):
         # Case 1
         if isinstance(coll, _CollectiveKernel):
             _, idx = inp.indices[0]
+            # pyrefly: ignore # bad-return
             return [coll.inputs[idx]]
         # Case 2
         return []
diff --git a/torch/_inductor/kernel/flex/common.py b/torch/_inductor/kernel/flex/common.py
index d4668fe95015..3cd3056a7600 100644
--- a/torch/_inductor/kernel/flex/common.py
+++ b/torch/_inductor/kernel/flex/common.py
@@ -90,13 +90,16 @@ def get_fwd_subgraph_outputs(
     subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults
 ) -> list[Optional[ComputedBuffer]]:
     subgraph_buffer = (
+        # pyrefly: ignore # bad-assignment
         subgraph_buffer
         if isinstance(subgraph_buffer, Sequence)
         else [subgraph_buffer]
     )
     mask_graph_buffer = (
+        # pyrefly: ignore # bad-assignment
         mask_graph_buffer
         if isinstance(mask_graph_buffer, Sequence)
         else [mask_graph_buffer]
     )
+    # pyrefly: ignore # not-iterable
     return [*subgraph_buffer, *mask_graph_buffer]
diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
index cf07ab944d02..e35bc9b08e5f 100644
--- a/torch/_inductor/kernel/flex/flex_attention.py
+++ b/torch/_inductor/kernel/flex/flex_attention.py
@@ -821,6 +821,7 @@ def flex_attention_backward(*args, **kwargs):
             **cur_kernel_options,
         )
         inputs_for_autotuning = (
+            # pyrefly: ignore # unsupported-operation
             [
                 query,
                 key,
@@ -891,9 +892,11 @@ def get_bwd_subgraph_outputs(
     joint_outputs: JointOutputResult,
 ) -> list[Optional[Union[ComputedBuffer, TensorBox]]]:
     subgraph_buffer = (
+        # pyrefly: ignore # bad-assignment
         subgraph_buffer
         if isinstance(subgraph_buffer, Sequence)
         else [subgraph_buffer]
     )
     mask_graph_buffer = (
+        # pyrefly: ignore # bad-assignment
         mask_graph_buffer
         if isinstance(mask_graph_buffer, Sequence)
         else [mask_graph_buffer]
@@ -905,4 +908,5 @@ def get_bwd_subgraph_outputs(
         *joint_outputs.mutated_grads,
     ]
+    # pyrefly: ignore # not-iterable
     return [*subgraph_buffer, *mask_graph_buffer, *joint_output_buffers]
diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
index e53a1788058f..4374a93e8d0b 100644
--- a/torch/_inductor/kernel/flex/flex_decoding.py
+++ b/torch/_inductor/kernel/flex/flex_decoding.py
@@ -367,6 +367,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
         ]
         inputs_for_flex_decoding = (
+            # pyrefly: ignore # unsupported-operation
             [
                 query,
                 key,
diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py
index f61a2852410e..5fd7ab4223ea 100644
--- a/torch/_inductor/kernel/mm_grouped.py
+++ b/torch/_inductor/kernel/mm_grouped.py
@@ -693,10 +693,12 @@ def _tuned_grouped_mm_common(
         if len(m2_size) == 2:
             m, k1 = m1_size
             k2, _ = m2_size
+            # pyrefly: ignore # missing-attribute
             g = offs.get_size()[0]
             V.graph.sizevars.check_equals(k1, k2)
             a_is_2d, b_is_2d = True, True
         else:
+            # pyrefly: ignore # missing-attribute
             g1 = offs.layout.size[0]
             m, k1 = m1_size
             g2, k2, _ = m2_size
@@ -705,6 +707,7 @@ def _tuned_grouped_mm_common(
             a_is_2d, b_is_2d = True, False
     else:
         if len(m2_size) == 2:
+            # pyrefly: ignore # missing-attribute
             g1 = offs.layout.size[0]
             g2, m, k1 = m1_size
             k2, _ = m2_size
diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py
index 76bd6fc87243..45e049141a2f 100644
--- a/torch/_inductor/loop_body.py
+++ b/torch/_inductor/loop_body.py
@@ -52,6 +52,7 @@ class InterpreterShim(torch.fx.Interpreter):
         self.current_node = None
     def run_node(self, n: torch.fx.Node) -> Any:
+        # pyrefly: ignore # bad-assignment
         self.current_node = n
         return super().run_node(n)
@@ -436,6 +437,7 @@ class LoopBody:
         if str(old) == str(new):
             return
         assert self.indexing is not None
+        # pyrefly: ignore # bad-assignment
         self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
     def get_index(self, name):
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 1ad7976e21c6..7001fe6a66d2 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -160,6 +160,7 @@ def group_foreach_args(arg_pairs: Iterable[Union[tuple[Any, Any], Any]]):
             break
         assert device is not None, "foreach op should have at least one tensor arg"
         if unpack_args:
+            # pyrefly: ignore # bad-unpacking
             (args,) = args
         out[(device, use_foreach)].append((i, args))
     return out
@@ -262,6 +263,7 @@ def decode_dtype(dtype: int):
     if not isinstance(dtype, int):
         return dtype
     assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
+    # pyrefly: ignore # bad-assignment
     dtype = DTYPE_ID_LOOKUP[dtype]
     return dtype
@@ -558,7 +560,9 @@ def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=No
         return inputs
     if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs):
         dtype = override_return_dtype or get_promoted_dtype(
-            *inputs, type_promotion_kind=type_promotion_kind
+            *inputs,
+            # pyrefly: ignore # bad-argument-type
+            type_promotion_kind=type_promotion_kind,
         )
         def const_func(x):
@@ -615,7 +619,9 @@ def make_pointwise(
         inputs = promote_constants(inputs, override_return_dtype)
         if allow_alpha:
             if alpha is not None and alpha != 1:
+                # pyrefly: ignore # bad-assignment
                 inputs = list(inputs)
+                # pyrefly: ignore # unsupported-operation
                 inputs[-1] = mul(inputs[-1], alpha)
         else:
             assert alpha is None
@@ -665,12 +671,14 @@ def make_pointwise(
         if not override_device:
             device = None
             for i in inputs:
+                # pyrefly: ignore # missing-attribute
                 if is_gpu(i.get_device().type):
                     device = i.get_device()
                     break
             if not device:
                 device = inputs[0].get_device()
+        # pyrefly: ignore # unbound-name
         device = override_device or device
         return Pointwise.create(
@@ -725,6 +733,7 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
                 outputs[output_ind] = output
                 if (
+                    # pyrefly: ignore # unbound-name
                     V.graph.has_feature(device, BackendFeature.FOREACH)
                     and use_foreach
                     and realize_outputs
@@ -733,6 +742,7 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
                     operation_list.append(output.get_operation_name())
         if operation_list:
+            # pyrefly: ignore # unbound-name
             V.graph.register_operation_list(operation_list)
         assert all(x is not None for x in outputs)
@@ -3072,8 +3082,10 @@ def copy(self, src, non_blocking=False):
         src = tensor(src, dtype=self.get_dtype(), device=self.get_device())
     x = src
     if self.get_device() != src.get_device():
+        # pyrefly: ignore # bad-argument-type
         x = to_device(x, self.get_device())
     if self.get_dtype() != src.get_dtype():
+        # pyrefly: ignore # bad-argument-type
         x = to_dtype(x, self.get_dtype())
     if self.get_size() != src.get_size():
@@ -3097,6 +3109,7 @@ def clone_preserve_reinterpret_view(x):
     reinterpret_view_layouts = []
     if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView):
         x = x.data  # unwrap TensorBox
+        # pyrefly: ignore # bad-assignment
         while isinstance(x, ir.ReinterpretView):
             reinterpret_view_layouts.append(x.get_layout())
             x = x.data
@@ -3175,6 +3188,7 @@ def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
     dim = _validate_dim(x, dim, 0)
     dim_size = x.get_size()[dim]
+    # pyrefly: ignore # bad-argument-type
     start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
     src_size = list(x.get_size())
@@ -3497,6 +3511,7 @@ def new_constant(fill_value):
         assert isinstance(size, (list, tuple))
         assert_nyi(not pin_memory, "pin_memory")
         assert_nyi(layout in (None, torch.strided), f"layout={layout}")
+        # pyrefly: ignore # bad-argument-type
         dtype = decode_dtype(dtype) or x.get_dtype()
         device = device or x.get_device()
         size = [sympy.Integer(s) for s in size]
@@ -3529,6 +3544,7 @@ def empty_strided(
     assert isinstance(stride, (list, tuple, type(None)))
     assert_nyi(not pin_memory, "pin_memory")
     assert_nyi(layout in (None, torch.strided), f"layout={layout}")
+    # pyrefly: ignore # bad-argument-type
     dtype = decode_dtype(dtype) or torch.get_default_dtype()
     device = device or torch.tensor(0.0).device
     device = decode_device(device)
@@ -4187,6 +4203,7 @@ def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool =
             return src_loader(idx)
         else:
             # src is a scalar
+            # pyrefly: ignore # bad-argument-type
             return ops.constant(src, self.get_dtype())
     def backend_reduce_str(reduce):
@@ -4540,6 +4557,7 @@ def constant_boundary_condition(
 ):
     h = x.get_size()[-dim:]
     x_loader = x.make_loader()
+    # pyrefly: ignore # unsupported-operation
     padding_h = padding or [0] * dim
     def load(index):
@@ -4548,6 +4566,7 @@ def constant_boundary_condition(
         mask = functools.reduce(
             ops.and_,
+            # pyrefly: ignore # no-matching-overload
             [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)],
         )
         return (
@@ -5429,6 +5448,7 @@ def upsample_nearest2d_backward(
     inp_h = V.graph.sizevars.guard_int(inp_h)
     inp_w = V.graph.sizevars.guard_int(inp_w)
+    # pyrefly: ignore # not-iterable
     *_batch, out_h, out_w = input_size
     if inp_h % out_h == 0 and inp_w % out_w == 0:
@@ -5461,6 +5481,7 @@ def upsample_nearest2d_backward(
         device=x.get_device(),
         dtype=x.get_dtype(),
         inner_fn=fn,
+        # pyrefly: ignore # no-matching-overload
         ranges=list(input_size),
     )
@@ -6316,6 +6337,7 @@ def pow(a, b):
     if isinstance(a, Number):
         if a == 1:
             return full_like(b, 1)
+        # pyrefly: ignore # missing-attribute
        if a == 2 and is_float_dtype(b.get_dtype()):
            return exp2(b)
@@ -6463,9 +6485,12 @@ def div_prim(a, b):
     # see https://github.com/pytorch/pytorch/issues/157959
     if (divisor := get_constant_value(b)) is not None and a.get_device().type != "cpu":
         # Replace divide by constant with multiply by reciprocal
+        # pyrefly: ignore # unbound-name
         if divisor.value == 0:
+            # pyrefly: ignore # unbound-name
             reciprocal = math.copysign(float("inf"), divisor.value)
         else:
+            # pyrefly: ignore # unbound-name
             reciprocal = 1.0 / divisor.value
         return mul(a, reciprocal)
diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py
index 866c22abd069..16c77556e69d 100644
--- a/torch/_inductor/mkldnn_ir.py
+++ b/torch/_inductor/mkldnn_ir.py
@@ -83,6 +83,7 @@ def _prepare_convolution_fusion_create(
         output_size.append(input_size[0])
         output_size.append(weight_size[0])
         for d in range(2, dim):
+            # pyrefly: ignore # unsupported-operation
             dilation_ = dilation[d - 2] if has_dilation else 1
             kernel = dilation_ * (weight_size[d] - 1) + 1
             output_size_d = (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[
@@ -409,6 +410,7 @@ class ConvolutionBinary(ExternKernelAlloc):
         ) = _prepare_convolution_fusion_create(
             cls, x, weight, bias, padding_, stride_, dilation_, groups
         )
+        # pyrefly: ignore # bad-assignment
         other = cls.require_stride_order(other, req_stride_order)
         inputs.insert(1, other)
         constant_args = constant_args + [
@@ -486,6 +488,7 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
         ) = _prepare_convolution_fusion_create(
             cls, x, weight, bias, padding_, stride_, dilation_, groups
         )
+        # pyrefly: ignore # bad-assignment
         other = cls.require_stride_order(other, req_stride_order)
         inputs.insert(1, other)
         constant_args = constant_args + [
@@ -1216,16 +1219,23 @@ class MkldnnRnnLayer(ExternKernelAlloc):
         batch_first: bool,
         train: bool,
     ):
+        # pyrefly: ignore # bad-assignment
         x = cls.require_stride1(cls.realize_input(x))
         # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer.
         # Make sure x is contiguous in batch_first case.
         x.freeze_layout()
+        # pyrefly: ignore # bad-assignment
         w0 = cls.require_stride1(cls.realize_input(w0))
+        # pyrefly: ignore # bad-assignment
         w1 = cls.require_stride1(cls.realize_input(w1))
+        # pyrefly: ignore # bad-assignment
         w2 = cls.require_stride1(cls.realize_input(w2))
+        # pyrefly: ignore # bad-assignment
         w3 = cls.require_stride1(cls.realize_input(w3))
+        # pyrefly: ignore # bad-assignment
         hx = cls.require_stride1(cls.realize_input(hx))
         hx.freeze_layout()
+        # pyrefly: ignore # bad-assignment
         cx = cls.require_stride1(cls.realize_input(cx))
         cx.freeze_layout()
diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py
index b39092772903..2cf844c9f721 100644
--- a/torch/_inductor/mkldnn_lowerings.py
+++ b/torch/_inductor/mkldnn_lowerings.py
@@ -142,6 +142,7 @@ def grouped_gemm_lowering(
     num_gemm = len(w)
     assert config.max_autotune or config.max_autotune_gemm
+    # pyrefly: ignore # bad-assignment
     b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]
     choices: list[ChoiceCaller] = []
@@ -176,6 +177,7 @@ def grouped_gemm_lowering(
         ir.MultiOutput(layout, template_buf, [(list, gemm_idx)])
         for gemm_idx in range(num_gemm)
     ]
+    # pyrefly: ignore # bad-argument-type
     template_buf.layout = ir.MultiOutputLayout(device=input_nodes[0].get_device())
     template_buf.outputs = return_bufs
     return_tensors = [
@@ -424,6 +426,7 @@ def register_onednn_fusion_ops():
                     "epilogue_creator": epilogue_creator,
                 }
+                # pyrefly: ignore # unsupported-operation
                 kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
                 CppGemmTemplate.add_choices(
                     choices,
@@ -721,6 +724,7 @@ def register_onednn_fusion_ops():
                 # If w_zp is None, then it's a dummy tensor created to denote the
                 # absence of a zero point, and thus w is int8 symmetrically quantized.
                # Moreover, oneDNN qlinear API doesn't accept None value for zp
+                # pyrefly: ignore # bad-assignment
                 w_zp = V.graph.add_tensor_constant(
                     torch.tensor(0, dtype=torch.int32), name="w_zp"
                 )
@@ -764,7 +768,9 @@ def register_onednn_fusion_ops():
                     ) = create_int8_compensation(
                         W_tensor,
                         packed_weight,
+                        # pyrefly: ignore # bad-argument-type
                         x_scale,
+                        # pyrefly: ignore # bad-argument-type
                         x_zp,
                         w_scale,
                     )
@@ -823,6 +829,7 @@ def register_onednn_fusion_ops():
                         )
                         # Step 2: add Bias if applicable
                         if bias is not None:
+                            # pyrefly: ignore # not-callable
                             _bias = bias_loader(weight_compens_index)
                             nonlocal bias_dtype
                             assert bias_dtype in [torch.float32, torch.bfloat16]
@@ -1013,6 +1020,7 @@ def register_onednn_fusion_ops():
                 )
             if w_zp is None:
+                # pyrefly: ignore # bad-assignment
                 w_zp = V.graph.add_tensor_constant(
                     torch.tensor(0, dtype=torch.int32), name="w_zp"
                 )
@@ -1087,7 +1095,9 @@ def register_onednn_fusion_ops():
                     ) = create_int8_compensation(
                         W_tensor,
                         packed_weight,
+                        # pyrefly: ignore # bad-argument-type
                         x_scale,
+                        # pyrefly: ignore # bad-argument-type
                         x_zp,
                         w_scale,
                     )
@@ -1147,6 +1157,7 @@ def register_onednn_fusion_ops():
                         )
                         # Step 2: add Bias if applicable
                         if bias is not None:
+                            # pyrefly: ignore # not-callable
                             _bias = bias_loader(weight_compens_index)
                             nonlocal bias_dtype
                             assert bias_dtype in [torch.float32, torch.bfloat16]
diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py
index 5872764c8984..42be50270995 100644
--- a/torch/_inductor/output_code.py
+++ b/torch/_inductor/output_code.py
@@ -110,6 +110,7 @@ _StrideExprStr: TypeAlias = str
 # to achieve writing to all values of that dimension of the input tensor
 def get_expanded_dims(t: torch.Tensor) -> list[int]:
     if not isinstance(t, torch.Tensor):
+        # pyrefly: ignore # bad-return
         return None
     return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]
diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py
index bd11d033cadb..7c7884c92dba 100644
--- a/torch/_inductor/package/package.py
+++ b/torch/_inductor/package/package.py
@@ -131,6 +131,7 @@ def load_package(
         )
         return AOTICompiledModel(loader)
+    # pyrefly: ignore # no-matching-overload
     path = os.fspath(path)  # AOTIModelPackageLoader expects (str, str)
     loader = torch._C._aoti.AOTIModelPackageLoader(
         path, model_name, run_single_threaded, num_runners, device_index
     )
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index 3bb06df651c7..aaedf37b4eb2 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -270,6 +270,7 @@ class Match:
                 ]
             )
+        # pyrefly: ignore # bad-context-manager
         with context:
             if trace_fn is None:
                 trace_fn = functools.partial(
@@ -1197,9 +1198,13 @@ class ReplacementPatternEntry(PatternEntry):
                 if graph_name is None:
                     assert isinstance(target, str)
                     _, graph_name = unique_graph_name_with_root(
-                        graph.owning_module, target
+                        # pyrefly: ignore # unbound-name
+                        graph.owning_module,
+                        target,
                     )
+                    # pyrefly: ignore # unbound-name
                     graph.owning_module.register_module(graph_name, sub_gm)
+                # pyrefly: ignore # unbound-name
                 getattr_node = graph.get_attr(graph_name)
                 added_replacement_nodes.append(getattr_node)
                 return getattr_node
@@ -1498,6 +1503,7 @@ def register_replacement(
             return search_fn(*args_new[len(args_new) - len(args) :])
         try:
+            # pyrefly: ignore # bad-argument-type
             specific_graph = trace_fn(search_fn_new, sym_args + args)
         except RuntimeError as e:
             log_trace_failure(search_fn, e)
@@ -1648,6 +1654,7 @@ def _serialize_pattern(
             if isinstance(attr, type) and issubclass(
                 attr, (PatternExpr, _TargetExpr)
             ):
+                # pyrefly: ignore # bad-argument-type
                 pattern_matcher_imports.append(name)
         except TypeError:
             pass
@@ -2055,10 +2062,14 @@ def fx_to_pattern(
     argnum = itertools.count()
     class Converter(torch.fx.Interpreter):
+        # pyrefly: ignore # bad-override
         call_method = _not_implemented
+        # pyrefly: ignore # bad-override
         call_module = _not_implemented
+        # pyrefly: ignore # bad-override
         get_attr = _not_implemented
+        # pyrefly: ignore # bad-override
         def placeholder(
             self,
             target: str,  # type: ignore[override]
@@ -2079,6 +2090,7 @@ def fx_to_pattern(
             else:
                 return KeywordArg(name)
+        # pyrefly: ignore # bad-override
         def call_function(
             self,
             target: str,  # type: ignore[override]
@@ -2113,6 +2125,7 @@ def fx_to_pattern(
                 assert isinstance(args, Collection)
                 assert len(rv) == len(args)
                 for r, arg in zip(rv, args):
+                    # pyrefly: ignore # missing-attribute
                     r.users = len(arg.users)
             else:
                 rv.users = len(n.users)
@@ -2187,7 +2200,10 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph
         torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
     )
     GraphPatternEntry(
-        pattern=pattern, handler=pointless_view, extra_check=_return_true
+        pattern=pattern,
+        handler=pointless_view,
+        extra_check=_return_true,
+        # pyrefly: ignore # bad-argument-type
     ).register(matcher_pass.patterns)
     matcher_pass.apply(gm.graph)
@@ -2277,6 +2293,7 @@ def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 new_node.node.name = self.new_graph._graph_namespace.create_name(
                     old_node.name, None
                 )
+            # pyrefly: ignore # bad-return
             return new_node
     return CopyGraph(input_graph).transform()
diff --git a/torch/_inductor/quantized_lowerings.py b/torch/_inductor/quantized_lowerings.py
index c7628314a85c..5e192579bbec 100644
--- a/torch/_inductor/quantized_lowerings.py
+++ b/torch/_inductor/quantized_lowerings.py
@@ -137,6 +137,7 @@ def register_woq_mm_ops() -> None:
                 )
                 and mat2.get_layout().is_contiguous()
             ):
+                # pyrefly: ignore # bad-specialization, missing-attribute, not-a-type
                 CppWoqInt4GemmTemplate[qGroupSize].add_choices(
                     choices,
                     aten_layout,
diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py
index 1304ce79b86e..8b143520808f 100644
--- a/torch/_inductor/remote_cache.py
+++ b/torch/_inductor/remote_cache.py
@@ -160,6 +160,7 @@ class RemoteCache(Generic[_T]):
             self.backend = override_cls()
         else:
             self.backend = backend
+        # pyrefly: ignore # invalid-type-var
         self.serde = serde
     # See if the cache contains `key`. Returns `None` if the value is not
@@ -245,6 +246,7 @@ class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
     A Redis implementation of a remote/distributed cache.
     """
+    # pyrefly: ignore # missing-attribute
     _redis: Optional[redis.Redis] = None
     def __init__(self, cache_id: str) -> None:
@@ -267,7 +269,9 @@ class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
             return None
         try:
+            # pyrefly: ignore # missing-attribute
             value = self._redis.get(key)
+        # pyrefly: ignore # missing-attribute
         except redis.exceptions.ConnectionError:
             # Redis is lazy and doesn't actually attempt to connect until the
             # first use. Mark is as unavailable now.
@@ -285,7 +289,9 @@ class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
             return
         try:
+            # pyrefly: ignore # missing-attribute
             self._redis.set(key, data)
+        # pyrefly: ignore # missing-attribute
         except redis.exceptions.ConnectionError:
             # Redis is lazy and doesn't actually attempt to connect until the
             # first use. Mark is as unavailable now.
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 4c109de28641..ffa9bff9d5ee 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -1019,6 +1019,7 @@ def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float
         if mm_fn is None:
             return None
         bench_fn = mm_fn
+        # pyrefly: ignore # unbound-name
         args_kwargs_fn = lambda: snode_args_kwargs(snode)  # noqa: E731
     else:
         return None
@@ -1295,6 +1296,7 @@ class SchedulerNode(BaseSchedulerNode):
         new_order = self_dep.decide_loop_order_to_match(other_dep)
         if new_order:
+            # pyrefly: ignore # bad-assignment
             metrics.num_loop_reordering += 1
             loop_ordering_log.debug(
                 "Reorder loops for %s with order %s", self.get_name(), new_order
             )
@@ -1591,6 +1593,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
                 self.get_name(),
             )
             return False
+        # pyrefly: ignore # bad-assignment
         metrics.num_loop_reordering += 1
         loop_ordering_log.debug(
             "Reorder loops for fused node %s with order %s", self.get_name(), new_order
         )
@@ -2322,6 +2325,7 @@ class Scheduler:
         self.name_to_fused_node = {n.get_name(): n for n in self.nodes}
         self.compute_ancestors()
+        # pyrefly: ignore # bad-assignment
         metrics.ir_nodes_pre_fusion += len(self.nodes)
         from torch._inductor.debug import log_ir_post_fusion, log_ir_pre_fusion
@@ -2549,6 +2553,7 @@ class Scheduler:
             ]
             return DedupList(new_items, new_membership)
+        # pyrefly: ignore # not-a-type
         name_to_users: defaultdict[str, DedupList[NodeUser]] = collections.defaultdict(
             DedupList
         )
@@ -2585,12 +2590,14 @@ class Scheduler:
             else:
                 name_to_users[buf1_name] = name_to_users[buf2_name]
+        # pyrefly: ignore # not-a-type
         def rename(n: str) -> str:
             if n in self.mutation_renames:
                 return rename(self.mutation_renames[n])
             return n
         def add_user(
+            # pyrefly: ignore # not-a-type
             used_by_name: str,
             user_node: Union[BaseSchedulerNode, OutputNode],
             can_inplace: bool = False,
@@ -2600,6 +2607,7 @@ class Scheduler:
                 NodeUser(user_node, can_inplace, is_weak)
             )
+        # pyrefly: ignore # not-a-type
         unbacked_symbol_to_origin_node: dict[sympy.Symbol, Optional[str]] = {}
         # NB: None means that the dependency is on an input. Don't actually
@@ -2658,6 +2666,7 @@ class Scheduler:
                 and (dep := next(iter(node.read_writes.writes)))
                 and isinstance(dep, MemoryDep)
             ):
+                # pyrefly: ignore # unbound-name
                 node_mode = dep.mode
             else:
                 node_mode = None
@@ -3429,9 +3438,12 @@ class Scheduler:
                         str(e),
                     )
                     continue
+                # pyrefly: ignore # missing-attribute
                 with multi_node.swap_as_triton_caller(choice):
                     ms_fused, path = self.benchmark_codegened_module(
-                        mod_fused, device
+                        mod_fused,
+                        # pyrefly: ignore # bad-argument-type
+                        device,
                     )
                     new_timings[choice] = ms_fused
                     if ms_fused < min_ms_fused:
@@ -3443,12 +3455,15 @@ class Scheduler:
             if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None:
                 if config.multi_kernel_hints:
                     hint_override_best_fusion_choice[None] = ms_fused_choice
+                    # pyrefly: ignore # missing-attribute
                     multi_node.finalize_as_triton_callers(
                         hint_override_best_fusion_choice
                     )
                 else:
+                    # pyrefly: ignore # missing-attribute
                     multi_node.finalize_as_triton_caller(ms_fused_choice)
+                # pyrefly: ignore # missing-attribute
                 multi_node._choice_timings[None] = new_timings
                 return True
             else:
@@ -3478,21 +3493,27 @@ class Scheduler:
                 fut.result()
             ms1, path1 = self.benchmark_codegened_module(
-                future_and_mod_l1[1], device
+                future_and_mod_l1[1],
+                # pyrefly: ignore # bad-argument-type
+                device,
             )
             if math.isinf(ms1):
                 why("register spilling of the first kernel")
                 return False
             ms2, path2 = self.benchmark_codegened_module(
-                future_and_mod_l2[1], device
+                future_and_mod_l2[1],
+                # pyrefly: ignore # bad-argument-type
+                device,
             )
             if math.isinf(ms2):
                 why("register spilling of the second kernel")
                 return False
             ms_fused, path_fused = self.benchmark_codegened_module(
-                future_and_mod_l1_fused[1], device
+                future_and_mod_l1_fused[1],
+                # pyrefly: ignore # bad-argument-type
+                device,
             )
             if math.isinf(ms_fused):
                 why("register spilling of the fused kernel")
@@ -4323,6 +4344,7 @@ class Scheduler:
         if config.expand_dimension_for_pointwise_nodes and (
             expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)
         ):
+            # pyrefly: ignore # unbound-name
             (expand_dim, smaller_node, expand_size) = expand_analysis
             smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size)
             shared_data_score = self.score_fusion_memory(node1, node2)
@@ -4633,6 +4655,7 @@ class Scheduler:
                 device.type == "cuda"
                 and (device_props := torch.cuda.get_device_properties(device)).major < 7
             ):
+                # pyrefly: ignore # unbound-name
                 raise GPUTooOldForTriton(device_props, inspect.currentframe())
             elif is_gpu(device.type) and not device.type == "mps":
                 raise TritonMissing(inspect.currentframe())
@@ -4930,6 +4953,7 @@ class Scheduler:
             if isinstance(buf.node, ir.MutationOutput) and (
                 real_name := self.mutation_real_name.get(buf_name, None)
             ):
+                # pyrefly: ignore # unbound-name
                 return is_none_layout(real_name)
             return True
@@ -5028,6 +5052,7 @@ class Scheduler:
             signatures.append(partition_signature)
             unmet_output_names = partition_input_names.union(
+                # pyrefly: ignore # unsupported-operation
                 unmet_output_names - returned_output_names
             )
@@ -5410,6 +5435,7 @@ class Scheduler:
         self.current_device = self.default_device_context
+        # pyrefly: ignore # unbound-name
         if self.default_device_context and config.triton.autotune_at_compile_time:
             V.graph.wrapper_code.write_get_raw_stream_header()
@@ -5453,6 +5479,7 @@ class Scheduler:
                 prologue, template_node, epilogue = node.get_prologue_template_epilogue(
                     list(node.get_nodes())
                 )
+                # pyrefly: ignore # unbound-name
                 self.get_backend(device).codegen_template(
                     template_node, epilogue, prologue
                 )
@@ -5461,6 +5488,7 @@
                 self.codegen_extern_call(node)
             elif node.is_foreach():
                 node = typing.cast(ForeachKernelSchedulerNode, node)
+                # pyrefly: ignore # unbound-name
                 backend_ = self.get_backend(device)
                 from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
                 from .codegen.simd import SIMDScheduling
@@ -5471,12 +5499,15 @@ class Scheduler:
                     raise AssertionError(f"{type(self)=}")
                 backend.codegen_combo_kernel(node)
             elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
+                # pyrefly: ignore # unbound-name
                 self.get_backend(device).codegen_node(node)
             else:
                 assert isinstance(node, NopKernelSchedulerNode)
                 node.mark_run()
+            # pyrefly: ignore # unbound-name
             if config.triton.debug_sync_kernel:
+                # pyrefly: ignore # unbound-name
                 self.get_backend(device).codegen_sync()
             self.available_buffer_names.update(node.get_buffer_names())
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index c71adb3c9121..24de4ae373af 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -332,6 +332,7 @@ class ModificationWrapper(V.WrapperHandler):  # type: ignore[name-defined]
         """Convert index variable to symbolic form."""
         return sympy_index_symbol(str(index_var))
+    # pyrefly: ignore # bad-override
     def store(
         self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
     ) -> str:
@@ -435,6 +436,7 @@ class TritonTemplateKernel(TritonKernel):
         # for templates with fixed epilogues
         self.prefix_args = prefix_args
         self.suffix_args = suffix_args
+        # pyrefly: ignore # invalid-type-var
         self.epilogue_fn = epilogue_fn
         self.render_hooks = {}  # type: ignore[var-annotated]
         self.triton_meta: Optional[dict[str, object]] = None
@@ -552,6 +554,7 @@ class TritonTemplateKernel(TritonKernel):
         context = (
             contextlib.nullcontext
             if not self.ops_handler
+            # pyrefly: ignore # not-callable
             else lambda: V.set_ops_handler(self.ops_handler(V.get_ops_handler()))
         )
         with context():  # type: ignore[operator]
@@ -990,6 +993,7 @@ class TritonTemplateKernel(TritonKernel):
                     f"{output_name} = {value_str}.broadcast_to(xindex.shape)"
                 )
+        # pyrefly: ignore # bad-assignment
         self.ops_handler = StoreOutputSubstitution
         input_node = self.named_input_nodes[input_name]
@@ -1193,6 +1197,7 @@ class TritonTemplateKernel(TritonKernel):
                         val_shape[i],
                         i,
                         len(index_order),
+                        # pyrefly: ignore # missing-argument
                         block_name=range_tree.symt.name,
                     )
                 )
@@ -1206,6 +1211,7 @@ class TritonTemplateKernel(TritonKernel):
                 )
                 # Update the val_shape information to use consistent naming
                 # after the remapping.
+                # pyrefly: ignore # missing-argument
                 val_shape_copy[i] = range_tree.symt.name
             # Reverse the index symbols because TMA is indexed
             # as (x, y) whereas the variables will naturally be indexed
@@ -1283,6 +1289,7 @@ class TritonTemplateKernel(TritonKernel):
         if output_index == contiguous_index:
             output_index = sympy.Symbol("xindex", integer=True)
+        # pyrefly: ignore # bad-assignment
         self.template_out_shape = val_shape if val_shape else val
         acc_dtype = (
             triton_type_to_torch(self.meta["ACC_TYPE"])
@@ -1899,6 +1906,7 @@ class TritonTemplate(KernelTemplate):
             extra,
             input_call_args,
             prologue_supported_inputs,
+            # pyrefly: ignore # bad-argument-type
             kernel_args_sizevars_keys,
             kernel_options,
         )
@@ -2462,6 +2470,7 @@ class DataProcessorTemplateWrapper:
             self._postprocessor = lambda x: x
         assert "input_nodes" in kwargs
         assert "layout" in kwargs
+        # pyrefly: ignore # not-callable
         kwargs["input_nodes"], kwargs["layout"] = preprocessor(
             kwargs["input_nodes"], kwargs["layout"]
         )
@@ -2633,6 +2642,7 @@ class AlgorithmSelectorCache(PersistentCache):
                 choice for choice in choices if isinstance(choice, ExternKernelChoice)
             ]
             if len(externs) > 0:
+                # pyrefly: ignore # bad-return
                 return externs[0]
             else:
                 return choices[0]
@@ -3130,7 +3140,9 @@ class AlgorithmSelectorCache(PersistentCache):
         # de-duplicate args
         unique_example_inputs = {
             x.get_name(): input_gen_fns.get(
-                i, lambda x: cls.benchmark_example_value(x, hint_override=hint_override)
+                i,
+                lambda x: cls.benchmark_example_value(x, hint_override=hint_override),
+                # pyrefly: ignore # bad-argument-type
             )(x)
             for i, x in enumerate(input_nodes)
         }
@@ -3617,8 +3629,10 @@ class AlgorithmSelectorCache(PersistentCache):
                 ),
                 node.get_device(),
                 node.get_dtype(),
+                # pyrefly: ignore # missing-attribute
                 node.layout.offset,
                 V.graph.sizevars.size_hints(
+                    # pyrefly: ignore # bad-argument-type
                     V.graph.get_allocation_size(node),
                     fallback=config.unbacked_symint_fallback,
                     hint_override=hint_override,
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index 209b6f831e5b..ed2b44fc3bca 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -180,6 +180,7 @@ class SizeVarAllocator:
         def statically_known(expr):
             evaluated = self.shape_env._maybe_evaluate_static(
                 expr,
+                # pyrefly: ignore # bad-argument-type
                 axioms=axioms,
                 var_to_range=var_to_range_tuple,
             )
diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py
index d88dfda82379..84d7021688bc 100644
--- a/torch/_inductor/template_heuristics/triton.py
+++ b/torch/_inductor/template_heuristics/triton.py
@@ -1638,6 +1638,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
         )
         # Build options dict
+        # pyrefly: ignore # no-matching-overload
         options_dict = dict(
             EVEN_K=even_k_symbolic,
             USE_FAST_ACCUM=False,  # Option for _scaled_mm
@@ -1720,6 +1721,7 @@ class TMAWorkspaceMixin(MMTemplateConfigMixin):
         )
         return kwargs
+    # pyrefly: ignore # bad-override
     def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
         """
         TMA specific filtering, as num_warps=2 not safe for TMA
@@ -1944,6 +1946,7 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin):
    This inherits from BaseScaledMMConfigMixin and adds TMA-specific options.
    """
+    # pyrefly: ignore # bad-override
    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
        """
        TMA specific filtering:
@@ -1984,6 +1987,7 @@ class ScaledBlackwellTMAConfigMixin(
    This inherits from ScaledMMConfigMixin, which inherits the scale_mm_epilogue,
    and adds TMA-specific options.
    """
+    # pyrefly: ignore # bad-override
    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
        """
        Warp specialization-specific filtering (BlackwellTMATemplateConfigMixin)
@@ -2116,6 +2120,7 @@ class CUDAScaledMMTemplateConfigHeuristic(ScaledMMConfigMixin, CUDAConfigHeurist
        # Override mm_configs to use scaled_mm_configs
        self.mm_configs = self.scaled_mm_configs
+    # pyrefly: ignore # bad-override
    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
        configs = [c for c in configs if c.block_k >= 32]
        return super()._filter_configs(configs)
diff --git a/torch/_inductor/test_operators.py b/torch/_inductor/test_operators.py
index d3d2705f8c78..fffc71db1358 100644
--- a/torch/_inductor/test_operators.py
+++ b/torch/_inductor/test_operators.py
@@ -15,6 +15,7 @@ for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"):
 class Realize(Function):
     @staticmethod
+    # pyrefly: ignore # bad-override
     def forward(ctx: object, x: Tensor) -> Tensor:
         return torch.ops._inductor_test.realize(x)
diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py
index d7b64ba1b867..3142f97f8c40 100644
--- a/torch/_inductor/tiling_utils.py
+++ b/torch/_inductor/tiling_utils.py
@@ -126,6 +126,7 @@ def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]:
     # For the purposes of tiling/coalesced access, approximate ModularIndexing and FloorDiv
     # then check later
+    # pyrefly: ignore # missing-attribute
     eq_1_expr_simplified = eq_1_expr.replace(ModularIndexing, indexing_div_rep).replace(
         FloorDiv, indexing_div_rep
     )
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 93ac87d4b2ce..5291a0f8f9ab 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -678,6 +678,7 @@ def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
     """
     Variant of cache_on_self for properties. The only difference is the type signature.
     """
+    # pyrefly: ignore # bad-argument-type
     return cache_on_self(fn)
@@ -690,6 +691,7 @@ def aggregate_origins(
         return functools.reduce(
             operator.or_,
             [
+                # pyrefly: ignore # missing-attribute
                 node.node.origins
                 for node in node_schedule
                 if hasattr(node, "node") and node.node
@@ -1166,6 +1168,7 @@ def unload_xpu_triton_pyds() -> None:
                 result,
                 torch._inductor.runtime.triton_heuristics.TritonCompileResult,
             ):
+                # pyrefly: ignore # missing-attribute
                 result.kernel.run.mod.__del__()
         del sys.modules[module_name]
@@ -1439,6 +1442,7 @@ class IndentedBuffer:
     ) -> None:
         if isinstance(other_code, IndentedBuffer):
             dedent = float("inf")
+            # pyrefly: ignore # bad-assignment
             for line in other_code._lines:
                 if not isinstance(line, LineContext) and line:
                     dedent = min(dedent, len(line) - len(line.lstrip()))
@@ -2208,6 +2212,7 @@ def run_and_get_code(
 def run_and_get_kernels(
     fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
 ) -> tuple[_T, list[str]]:
+    # pyrefly: ignore # bad-argument-type
     result, source_codes = run_and_get_code(fn, *args, **kwargs)
     kernels = []
     for code in source_codes:
@@ -2268,6 +2273,7 @@ def get_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> list[str
 def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> str:
+    # pyrefly: ignore # bad-argument-type
     source_codes = get_code(fn, *args, **kwargs)
     # Can have two outputs if backwards was eagerly compiled
     assert 1 <= len(source_codes) <= 2, (
@@ -2279,6 +2285,7 @@ def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> s
 def run_and_get_triton_code(
     fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
 ) -> str:
+    # pyrefly: ignore # bad-argument-type
     _, source_codes = run_and_get_code(fn, *args, **kwargs)
     # Can have two outputs if backwards was eagerly compiled
     assert 1 <= len(source_codes) <= 2, (
@@ -3674,6 +3681,7 @@ def maybe_log_cudagraph_partition(
             and (fx_node := ir_node.get_origin_node())
             and (stack_trace := fx_node.meta.get("stack_trace", None))
         ):
+            # pyrefly: ignore # unbound-name
            warning_msg = f"{warning_msg}. Found from : \n {stack_trace}"
            perf_hint_log.warning(warning_msg)
diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py
index 9a527471c8cc..f8430064917e 100644
--- a/torch/_inductor/wrapper_benchmark.py
+++ b/torch/_inductor/wrapper_benchmark.py
@@ -144,6 +144,7 @@ def benchmark_all_kernels(
         launcher = triton_kernel.launchers[0]
         print(
             get_info_str(
+                # pyrefly: ignore # bad-argument-type
                 ms,
                 launcher.n_regs,
                 launcher.n_spills,
diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py
index 959990fa5e8e..d895ab377e7c 100644
--- a/torch/backends/cuda/__init__.py
+++ b/torch/backends/cuda/__init__.py
@@ -192,7 +192,6 @@ class cuBLASModule:
                 value, "allow_fp16_reduced_precision_reduction"
             )
             return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
-                # pyrefly: ignore # bad-argument-count
                 allow_reduced_precision,
                 # pyrefly: ignore # bad-argument-count
                 allow_splitk,
@@ -202,7 +201,6 @@ class cuBLASModule:
                 value, "allow_bf16_reduced_precision_reduction"
            )
            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
-                # pyrefly: ignore # bad-argument-count
                allow_reduced_precision,
                # pyrefly: ignore # bad-argument-count
                allow_splitk,
diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py
index f8b5a7a75b2f..8cc4c7993417 100644
--- a/torch/distributed/__init__.py
+++ b/torch/distributed/__init__.py
@@ -135,7 +135,7 @@ if is_available():
     # this.
     # pyrefly: ignore # deprecated
     from .distributed_c10d import *  # noqa: F403
-    from .distributed_c10d import (  # pyrefly: ignore # deprecated
+    from .distributed_c10d import (
         _all_gather_base,
         _coalescing_manager,
         _CoalescingManager,
diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py
index a0aff568f445..5dd56fc006c4 100644
--- a/torch/distributed/_functional_collectives.py
+++ b/torch/distributed/_functional_collectives.py
@@ -1176,7 +1176,7 @@ def all_gather_inplace(
     return tensor_list
-from torch.distributed.distributed_c10d import (  # pyrefly: ignore # deprecated
+from torch.distributed.distributed_c10d import (
     _all_gather_base as legacy_all_gather_base,
     _reduce_scatter_base as legacy_reduce_scatter_base,
     all_gather as legacy_all_gather,
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
index 00b63efccdb7..a1febff0a6fc 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -80,7 +80,6 @@ def _compress_hook(
     if torch.compiler.is_compiling():
         grad = dist._functional_collectives.all_reduce(
-            # pyrefly: ignore # bad-argument-type
             compressed_tensor,
             "sum",
             # pyrefly: ignore # bad-argument-type
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index c1ad28206a28..a98a9647aeb8 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -2002,7 +2002,6 @@ def _new_process_group_helper(
                 if not is_gloo_available():
                     raise RuntimeError("Distributed package doesn't have Gloo built in")
                 backend_class = ProcessGroupGloo(
-                    # pyrefly: ignore # bad-argument-type
                     backend_prefix_store,
                     group_rank,
                     group_size,
@@ -2048,7 +2047,6 @@ def _new_process_group_helper(
                 # RuntimeError if is_ucc_available() returns false.
                backend_class = ProcessGroupUCC(
-                    # pyrefly: ignore # bad-argument-type
                    backend_prefix_store,
                    group_rank,
                    group_size,
diff --git a/torch/distributed/elastic/metrics/__init__.py b/torch/distributed/elastic/metrics/__init__.py
index 40a429fb2a43..b07671fbac9d 100644
--- a/torch/distributed/elastic/metrics/__init__.py
+++ b/torch/distributed/elastic/metrics/__init__.py
@@ -142,7 +142,7 @@ Now all metrics in the group ``my_app`` will be printed to stdout as:
 from typing import Optional
-from .api import (  # noqa: F401; pyrefly: ignore # deprecated; pyrefly: ignore # deprecated
+from .api import (  # noqa: F401
     configure,
     ConsoleMetricHandler,
     get_elapsed_time_ms,
diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py
index 21daa1c6c99b..793c843e9920 100644
--- a/torch/distributed/fsdp/_init_utils.py
+++ b/torch/distributed/fsdp/_init_utils.py
@@ -905,7 +905,6 @@ def _materialize_meta_module(
             # As a contract to the user, only call `reset_parameters()` if
             # the module has directly managed parameters/buffers
             module_state_iter = itertools.chain(
-                # pyrefly: ignore # bad-argument-type
                 module.parameters(recurse=False),
                 # pyrefly: ignore # bad-argument-type
                 module.buffers(recurse=False),
diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py
index 337dafa5a8a5..0883db086f33 100644
--- a/torch/distributed/tensor/_dispatch.py
+++ b/torch/distributed/tensor/_dispatch.py
@@ -189,7 +189,6 @@ class OpDispatcher:
             local_tensor_args = (
                 pytree.tree_unflatten(
-                    # pyrefly: ignore # bad-argument-type
                     cast(list[object], op_info.local_args),
                     # pyrefly: ignore # bad-argument-type
                     op_info.args_tree_spec,
@@ -364,7 +363,6 @@ class OpDispatcher:
                 with redistribute_context:
                     resharded_local_tensor = redistribute_local_tensor(
-                        # pyrefly: ignore # bad-argument-type
                         local_tensor,
                         arg_spec,
                         # pyrefly: ignore # bad-argument-type
@@ -438,7 +436,6 @@ class OpDispatcher:
                     op_call, args_list
                 )
                 kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
-                    # pyrefly: ignore # bad-argument-type
                     op_call,
                     v,
                     # pyrefly: ignore # bad-argument-type
diff --git a/torch/distributed/tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py
index f788c0c11f55..9647b4bb93ef 100644
--- a/torch/distributed/tensor/examples/torchrec_sharding_example.py
+++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py
@@ -90,7 +90,6 @@ class LocalShardsWrapper(torch.Tensor):
         # TODO: we shall continually extend this function to support more ops if needed
         if func in supported_ops:
             res_shards_list = [
-                # pyrefly: ignore # index-error
                 func(shard, *args[1:], **kwargs)
                 # pyrefly: ignore # index-error
                 for shard in args[0].shards
diff --git a/torch/nn/utils/__init__.py b/torch/nn/utils/__init__.py
index ed9a83b13389..84145da93f7b 100644
--- a/torch/nn/utils/__init__.py
+++ b/torch/nn/utils/__init__.py
@@ -1,5 +1,5 @@
 from . import parametrizations, parametrize, rnn, stateless
-from .clip_grad import (  # pyrefly: ignore # deprecated
+from .clip_grad import (
     _clip_grads_with_norm_ as clip_grads_with_norm_,
     _get_total_norm as get_total_norm,
     clip_grad_norm,