diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d4c05a092c1d..729b11157485 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -118,9 +118,9 @@ jobs: CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" echo "Running all other linters" if [ "$CHANGED_FILES" = '*' ]; then - ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh else - ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh fi quick-checks: diff --git a/.lintrunner.toml b/.lintrunner.toml index 57f82a1699c3..411e4d2c215b 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -209,6 +209,46 @@ command = [ '@{{PATHSFILE}}' ] + +[[linter]] +code = 'PYREFLY' +include_patterns = [ + 'torch/**/*.py', + 'torch/**/*.pyi', + 'torchgen/**/*.py', + 'torchgen/**/*.pyi', + 'functorch/**/*.py', + 'functorch/**/*.pyi', +] +exclude_patterns = [] +command = [ + 'python3', + 'tools/linter/adapters/pyrefly_linter.py', + '--config=pyrefly.toml', +] +init_command = [ + 'python3', + 'tools/linter/adapters/pip_init.py', + '--dry-run={{DRYRUN}}', + 'numpy==2.1.0 ; python_version >= "3.12"', + 'expecttest==0.3.0', + 'pyrefly==0.36.2', + 'sympy==1.13.3', + 'types-requests==2.27.25', + 'types-pyyaml==6.0.2', + 'types-tabulate==0.8.8', + 'types-protobuf==5.29.1.20250403', + 'types-setuptools==79.0.0.20250422', + 'types-jinja2==2.11.9', + 'types-colorama==0.4.6', + 'filelock==3.18.0', + 'junitparser==2.1.1', + 'rich==14.1.0', + 'optree==0.17.0', + 'types-openpyxl==3.1.5.20250919', + 'types-python-dateutil==2.9.0.20251008' +] + [[linter]] code = 'CLANGTIDY' include_patterns = [ diff --git a/pyrefly.toml b/pyrefly.toml index b204f0819ff2..b643be2265e7 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -1,5 +1,7 @@ # A Pyrefly configuration for PyTorch # Based on https://github.com/pytorch/pytorch/blob/main/mypy.ini +python-version = "3.12" + project-includes = [ "torch", "caffe2", @@ -36,6 +38,7 @@ project-excludes = [ "torch/nn/modules/rnn.py", # only remove when parsing errors are fixed "torch/_inductor/codecache.py", "torch/distributed/elastic/metrics/__init__.py", + "torch/_inductor/fx_passes/bucketing.py", # ==== "benchmarks/instruction_counts/main.py", "benchmarks/instruction_counts/definitions/setup.py", diff --git a/tools/linter/adapters/pyrefly_linter.py b/tools/linter/adapters/pyrefly_linter.py new file mode 100644 index 000000000000..77ed9c681e52 --- /dev/null +++ b/tools/linter/adapters/pyrefly_linter.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import argparse +import json +import logging +import os +import re +import subprocess +import sys +import time +from enum import Enum +from typing import NamedTuple + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: str | None + line: int | None + char: int | None + code: str + severity: LintSeverity + name: str + original: str | None + replacement: str | None + description: str | None + + +# Note: This regex pattern is kept for reference but not used for pyrefly JSON parsing +RESULTS_RE: re.Pattern[str] = re.compile( + r"""(?mx) + ^ + 
(?P<file>.*?): + (?P<line>\d+): + (?:(?P<column>-?\d+):)? + \s(?P<severity>\S+?):? + \s(?P<message>.*) + \s(?P<code>\[.*\]) + $ + """ +) + + +# torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR +INTERNAL_ERROR_RE: re.Pattern[str] = re.compile( + r"""(?mx) + ^ + (?P<file>.*?): + (?P<line>\d+): + \s(?P<severity>\S+?):? + \s(?P<message>INTERNAL\sERROR.*) + $ + """ +) + + +def run_command( + args: list[str], + *, + extra_env: dict[str, str] | None, + retries: int, +) -> subprocess.CompletedProcess[bytes]: + logging.debug("$ %s", " ".join(args)) + start_time = time.monotonic() + try: + return subprocess.run( + args, + capture_output=True, + ) + finally: + end_time = time.monotonic() + logging.debug("took %dms", (end_time - start_time) * 1000) + + +# Severity mapping (currently only used for stderr internal errors) +# Pyrefly JSON output doesn't include severity, so all errors default to ERROR +severities = { + "error": LintSeverity.ERROR, + "note": LintSeverity.ADVICE, +} + + +def check_pyrefly_installed(code: str) -> list[LintMessage]: + cmd = ["pyrefly", "--version"] + try: + subprocess.run(cmd, check=True, capture_output=True) + return [] + except subprocess.CalledProcessError as e: + msg = e.stderr.decode(errors="replace") + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=f"Could not run '{' '.join(cmd)}': {msg}", + ) + ] + + +def in_github_actions() -> bool: + return bool(os.getenv("GITHUB_ACTIONS")) + + +def check_files( + code: str, + config: str, +) -> list[LintMessage]: + try: + pyrefly_commands = [ + "pyrefly", + "check", + "--config", + config, + "--output-format=json", + ] + proc = run_command( + [*pyrefly_commands], + extra_env={}, + retries=0, + ) + except OSError as err: + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=(f"Failed due to {err.__class__.__name__}:\n{err}"), + ) + ] + stdout = str(proc.stdout, "utf-8").strip() + stderr = str(proc.stderr, "utf-8").strip() + if proc.returncode not in (0, 1): + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=stderr, + ) + ] + + # Parse JSON output from pyrefly + try: + if stdout: + result = json.loads(stdout) + errors = result.get("errors", []) + else: + errors = [] + # For now filter out deprecated warnings and only report type errors as warnings + # until we remove mypy + errors = [error for error in errors if error["name"] != "deprecated"] + rc = [ + LintMessage( + path=error["path"], + name=error["name"], + description=error.get( + "description", error.get("concise_description", "") + ), + line=error["line"], + char=error["column"], + code=code, + severity=LintSeverity.ADVICE, + # uncomment and replace when we switch to pyrefly + # severity=LintSeverity.ADVICE if error["name"] == "deprecated" else LintSeverity.ERROR, + original=None, + replacement=None, + ) + for error in errors + ] + except (json.JSONDecodeError, KeyError, TypeError) as e: + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="json-parse-error", + original=None, + replacement=None, + description=f"Failed to parse pyrefly JSON output: {e}", + ) + ] + + # Still check stderr for internal errors + rc += [ + LintMessage( + path=match["file"], + name="INTERNAL ERROR", +
description=match["message"], + line=int(match["line"]), + char=None, + code=code, + severity=severities.get(match["severity"], LintSeverity.ERROR), + original=None, + replacement=None, + ) + for match in INTERNAL_ERROR_RE.finditer(stderr) + ] + return rc + + +def main() -> None: + parser = argparse.ArgumentParser( + description="pyrefly wrapper linter.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--code", + default="PYREFLY", + help="the code this lint should report as", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "--config", + required=True, + help="path to a pyrefly config file (e.g. pyrefly.toml)", + ) + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.INFO, + stream=sys.stderr, + ) + + lint_messages = check_pyrefly_installed(args.code) + check_files( + args.code, args.config + ) + for lint_message in lint_messages: + print(json.dumps(lint_message._asdict()), flush=True) + + +if __name__ == "__main__": + main() diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index 584a6d376bd3..1c32239bfaab 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -76,6 +76,7 @@ class StreamVariable(VariableTracker): super().__init__(**kwargs) self.proxy = proxy self.value = value + # pyrefly: ignore # read-only self.device = device def python_type(self) -> type: diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 4fc9d8c2e79d..6d5b759ac05b 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -1492,6 +1492,7 @@ def _aot_stage2a_partition( # apply joint_gm callback here if callable(torch._functorch.config.joint_custom_pass): + # pyrefly: ignore # bad-assignment fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs) static_lifetime_input_indices = fw_metadata.static_input_indices @@ -1761,6 +1762,7 @@ def _aot_stage2b_bw_compile( # tensor which is wrong. ph_size = ph_arg.size() + # pyrefly: ignore # bad-argument-type if len(ph_size) == 0 and len(real_stride) > 0: # Fix for 0-dimensional tensors: When a tensor becomes 0-d # (e.g., via squeeze), its stride should be () not (1,).
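Context for the suppressions used throughout the remaining hunks (illustrative, not part of the patch): a `# pyrefly: ignore # <error-code>` comment silences the named pyrefly error on the line that follows, or on the same line when written as a trailing comment, which is how this patch applies it. A minimal hypothetical example:

    def expects_int(x: int) -> int:
        return x + 1

    # pyrefly: ignore # bad-argument-type
    expects_int("3")  # without the comment, pyrefly would report bad-argument-type here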
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 906a38e7b7d5..46ffe463c94b 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -628,6 +628,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None: position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs) ] # add the scale nodes to the output find the first sym_node in the output + # pyrefly: ignore # bad-argument-type idx = find_first_sym_node(output_updated_args) scale_nodes = tensor_scale_nodes + sym_scale_nodes if scale_nodes: diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 84d6bc5a1950..a0f213a1e496 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -86,7 +86,7 @@ def bucket_all_gather( mode: BucketMode = "default", ) -> None: if bucket_cap_mb_by_bucket_idx is None: - from torch._inductor.fx_passes.bucketing import ( + from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute bucket_cap_mb_by_bucket_idx_default, ) @@ -103,7 +103,7 @@ def bucket_reduce_scatter( mode: BucketMode = "default", ) -> None: if bucket_cap_mb_by_bucket_idx is None: - from torch._inductor.fx_passes.bucketing import ( + from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute bucket_cap_mb_by_bucket_idx_default, ) diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py index c3eed5660479..5aad94b781e9 100644 --- a/torch/_inductor/fx_passes/freezing_patterns.py +++ b/torch/_inductor/fx_passes/freezing_patterns.py @@ -209,6 +209,7 @@ def addmm_patterns_init(): # pyrefly: ignore # bad-argument-type int8_woq_fusion_replacement, [val(), val(), val(), val(), scale(), scale(), scale()], + # pyrefly: ignore # bad-argument-type fwd_only, # pyrefly: ignore # bad-argument-type pass_patterns[0], @@ -230,6 +231,7 @@ def addmm_patterns_init(): # pyrefly: ignore # bad-argument-type matmul_replacement, [val(), val(), val(), val()], + # pyrefly: ignore # bad-argument-type fwd_only, # pyrefly: ignore # bad-argument-type pass_patterns[0], @@ -251,6 +253,7 @@ def addmm_patterns_init(): # pyrefly: ignore # bad-argument-type matmul_replacement_two, [val(), val(), val()], + # pyrefly: ignore # bad-argument-type fwd_only, # pyrefly: ignore # bad-argument-type pass_patterns[0], @@ -276,6 +279,7 @@ def addmm_patterns_init(): # pyrefly: ignore # bad-argument-type addmm_fuse_replacement_second, [val() for _ in range(7)], + # pyrefly: ignore # bad-argument-type fwd_only, # pyrefly: ignore # bad-argument-type pass_patterns[0], diff --git a/torch/_inductor/fx_passes/misc_patterns.py b/torch/_inductor/fx_passes/misc_patterns.py index 538a2ca2c43b..7b157bf03a91 100644 --- a/torch/_inductor/fx_passes/misc_patterns.py +++ b/torch/_inductor/fx_passes/misc_patterns.py @@ -49,6 +49,7 @@ def _misc_patterns_init(): # pyrefly: ignore # bad-argument-type randperm_index_add_replacement, [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)], + # pyrefly: ignore # bad-argument-type fwd_only, # pyrefly: ignore # bad-argument-type [post_grad_patterns, joint_graph_patterns], @@ -68,6 +69,7 @@ def _misc_patterns_init(): # pyrefly: ignore # bad-argument-type randperm_index_replacement, [torch.empty(4, 8, device=device)], + # pyrefly: ignore # bad-argument-type fwd_only, # pyrefly: ignore # bad-argument-type [post_grad_patterns, joint_graph_patterns], diff --git 
a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 8d1b31eb4067..74fa91ccc75c 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -919,6 +919,7 @@ def _pad_mm_init() -> None: pattern, replacement, args, + # pyrefly: ignore # bad-argument-type joint_fwd_bwd, # pyrefly: ignore # bad-argument-type patterns, @@ -931,6 +932,7 @@ def _pad_mm_init() -> None: pattern, replacement, args, + # pyrefly: ignore # bad-argument-type fwd_only, # pyrefly: ignore # bad-argument-type patterns, diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index c9a83000d215..3efd96883c5b 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -666,6 +666,7 @@ def lazy_init(): prepare_softmax_replacement, [torch.empty(4, 8)], scalar_workaround=dict(dim=-1), + # pyrefly: ignore # bad-argument-type trace_fn=fwd_only, # pyrefly: ignore # bad-argument-type pass_dicts=pass_patterns[1], @@ -730,6 +731,7 @@ def register_lowering_pattern( return pattern_matcher.register_lowering_pattern( pattern, extra_check, + # pyrefly: ignore # bad-argument-type pass_dict=pass_patterns[pass_number], ) @@ -1573,6 +1575,7 @@ def register_partial_reduction_pattern(): @register_graph_pattern( MultiOutputPattern([partial_reduc, full_reduc]), + # pyrefly: ignore # bad-argument-type pass_dict=pass_patterns[2], ) def reuse_partial(match, input, reduced_dims, keepdim): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index aab0b346ed62..6df8f06cc02e 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -27,7 +27,7 @@ from torch._dynamo.utils import counters from torch._higher_order_ops.associative_scan import associative_scan_op from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation from torch._library.utils import get_layout_constraint_tag -from torch._prims_common import ( +from torch._prims_common import ( # pyrefly: ignore # deprecated canonicalize_dim, canonicalize_dims, check, diff --git a/torch/_library/opaque_object.py b/torch/_library/opaque_object.py index b3460fa2dda8..cbe8795ec531 100644 --- a/torch/_library/opaque_object.py +++ b/torch/_library/opaque_object.py @@ -173,6 +173,7 @@ def register_opaque_type(cls: Any, name: Optional[str] = None) -> None: f"Unable to accept name, {name}, for this opaque type as it contains a '.'" ) _OPAQUE_TYPES[cls] = name + # pyrefly: ignore # missing-attribute torch._C._register_opaque_type(name) @@ -182,4 +183,5 @@ def is_opaque_type(cls: Any) -> bool: """ if cls not in _OPAQUE_TYPES: return False + # pyrefly: ignore # missing-attribute return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls]) diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 8cc4c7993417..f8b5a7a75b2f 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -135,7 +135,7 @@ if is_available(): # this. 
# pyrefly: ignore # deprecated from .distributed_c10d import * # noqa: F403 - from .distributed_c10d import ( + from .distributed_c10d import ( # pyrefly: ignore # deprecated _all_gather_base, _coalescing_manager, _CoalescingManager, diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 5dd56fc006c4..70dc50f1591a 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -1009,8 +1009,8 @@ lib_impl.impl("broadcast", _broadcast_meta, "Meta") lib_impl.impl("broadcast_", _broadcast__meta, "Meta") # mark these ops has side effect so that they won't be removed by DCE -torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) -torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) +torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) # type: ignore[has-type] +torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) # type: ignore[has-type] # Register legacy ops for backward compatibility # TODO(yifu): remove these in functional collective beta release @@ -1176,7 +1176,7 @@ def all_gather_inplace( return tensor_list -from torch.distributed.distributed_c10d import ( +from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated _all_gather_base as legacy_all_gather_base, _reduce_scatter_base as legacy_reduce_scatter_base, all_gather as legacy_all_gather, @@ -1190,11 +1190,11 @@ from torch.distributed.distributed_c10d import ( # This dict should contain sets of functions that dynamo is allowed to remap. # Functions in this set should accept the same args/kwargs 1:1 as their mapping. traceable_collective_remaps = { - legacy_allgather: all_gather_tensor_inplace, - legacy_reducescatter: reduce_scatter_tensor_inplace, - legacy_allreduce: all_reduce_inplace, - legacy_all_to_all_single: all_to_all_inplace, - legacy_all_gather: all_gather_inplace, - legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, - legacy_all_gather_base: all_gather_tensor_inplace, + legacy_allgather: all_gather_tensor_inplace, # type: ignore[has-type] + legacy_reducescatter: reduce_scatter_tensor_inplace, # type: ignore[has-type] + legacy_allreduce: all_reduce_inplace, # type: ignore[has-type] + legacy_all_to_all_single: all_to_all_inplace, # type: ignore[has-type] + legacy_all_gather: all_gather_inplace, # type: ignore[has-type] + legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, # type: ignore[has-type] + legacy_all_gather_base: all_gather_tensor_inplace, # type: ignore[has-type] } diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index ee715b8afee6..d3ccbf7c5910 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -393,6 +393,7 @@ class LocalTensor(torch.Tensor): def __repr__(self) -> str: # type: ignore[override] parts = [] for k, v in self._local_tensors.items(): + # pyrefly: ignore # bad-argument-type parts.append(f" {k}: {v}") tensors_str = ",\n".join(parts) return f"LocalTensor(\n{tensors_str}\n)" @@ -680,6 +681,7 @@ class LocalTensorMode(TorchDispatchMode): def _unpatch_device_mesh(self) -> None: assert self._old_get_coordinate is not None DeviceMesh.get_coordinate = self._old_get_coordinate + # pyrefly: ignore # bad-assignment self._old_get_coordinate = None diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index f49a1e33ce24..43745218afd8 100644 --- 
a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -316,6 +316,7 @@ def _local_all_gather_( assert len(input_tensors) == 1 input_tensor = input_tensors[0] + # pyrefly: ignore # bad-assignment output_tensors = output_tensors[0] ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) @@ -336,10 +337,12 @@ def _local_all_gather_( source_tensor = input_tensor if isinstance(input_tensor, LocalTensor): source_tensor = input_tensor._local_tensors[rank_i] + # pyrefly: ignore # missing-attribute output_tensors[i].copy_(source_tensor) work = FakeWork() work_so = Work.boxed(work) + # pyrefly: ignore # bad-return return ([output_tensors], work_so) @@ -426,6 +429,7 @@ def _local_scatter_( assert len(output_tensors) == 1 assert len(input_tensors) == 1 output_tensor = output_tensors[0] + # pyrefly: ignore # bad-assignment input_tensors = input_tensors[0] ranks, group_offsets, offset = _prepare_collective_groups(process_group_so) diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index 3dbda8445cd7..e12f41c4858b 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -90,6 +90,7 @@ class DTensorSpec: if not isinstance(self.placements, tuple): self.placements = tuple(self.placements) if self.shard_order is None: + # pyrefly: ignore # bad-assignment self.shard_order = DTensorSpec.compute_default_shard_order(self.placements) self._hash: int | None = None diff --git a/torch/export/_trace.py b/torch/export/_trace.py index ee54cf07897e..803c9fc2080d 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -701,6 +701,7 @@ def _restore_state_dict( for name, _ in list( chain( original_module.named_parameters(remove_duplicate=False), + # pyrefly: ignore # bad-argument-type original_module.named_buffers(remove_duplicate=False), ) ): diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 27b81f49fe9c..a608020f30f3 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -218,6 +218,7 @@ class FlexKernelOptions(TypedDict, total=False): waves_per_eu: NotRequired[int] """ROCm-specific waves per execution unit.""" + # pyrefly: ignore # invalid-annotation force_flash: NotRequired[bool] """ If True, forces use of the cute-dsl flash attention kernel. diff --git a/torch/nn/utils/__init__.py b/torch/nn/utils/__init__.py index 84145da93f7b..ed9a83b13389 100644 --- a/torch/nn/utils/__init__.py +++ b/torch/nn/utils/__init__.py @@ -1,5 +1,5 @@ from . 
import parametrizations, parametrize, rnn, stateless -from .clip_grad import ( +from .clip_grad import ( # pyrefly: ignore # deprecated _clip_grads_with_norm_ as clip_grads_with_norm_, _get_total_norm as get_total_norm, clip_grad_norm, diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index 9d6cc2a2b691..42cf898bfdf0 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -283,6 +283,7 @@ def clip_grad_value_( clip_value = float(clip_value) grads = [p.grad for p in parameters if p.grad is not None] + # pyrefly: ignore # bad-argument-type grouped_grads = _group_tensors_by_device_and_dtype([grads]) for (device, _), ([grads], _) in grouped_grads.items(): diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index e93458495617..5a48b690cfe0 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -111,8 +111,10 @@ class _Orthogonal(Module): Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2) if hasattr(self, "base"): + # pyrefly: ignore # unbound-name Q = self.base @ Q if transposed: + # pyrefly: ignore # unbound-name Q = Q.mT return Q # type: ignore[possibly-undefined] diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 9bd1ffe74ad9..06b12d8b1931 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -170,6 +170,7 @@ class TorchTensor(ir.Tensor): if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): raise TypeError( + # pyrefly: ignore # missing-attribute f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor " "with a tensor backed by real data using ONNXProgram.apply_weights() " "or save the model without initializers by setting include_initializers=False." diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 38f8585b1cc6..80674c0a39da 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -297,6 +297,7 @@ class AveragedModel(Module): avg_fn = get_swa_avg_fn() n_averaged = self.n_averaged.to(device) for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment] + # pyrefly: ignore # missing-attribute p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged)) else: for p_averaged, p_model in zip( # type: ignore[assignment] diff --git a/torch/quantization/_quantized_conversions.py b/torch/quantization/_quantized_conversions.py index 8d930c366c0d..54f40dcf7b25 100644 --- a/torch/quantization/_quantized_conversions.py +++ b/torch/quantization/_quantized_conversions.py @@ -71,6 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass( nrows // 16, 16 ) ).view(-1) + # pyrefly: ignore # unbound-name outp = outp.index_copy(1, cols_permuted, outp) # interleave_column_major_tensor diff --git a/torch/sparse/_semi_structured_ops.py b/torch/sparse/_semi_structured_ops.py index eed657550a7e..55cb0a8c113e 100644 --- a/torch/sparse/_semi_structured_ops.py +++ b/torch/sparse/_semi_structured_ops.py @@ -67,6 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor: # Because we cannot go from the compressed representation back to the dense representation currently, # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix # is the first or second argument, we expect an even / odd number of calls to transpose respectively. 
+ # pyrefly: ignore # no-matching-overload return self.__class__( torch.Size([self.shape[-1], self.shape[0]]), packed=self.packed_t, diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index be648fd84e7e..7fcdd8687933 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -184,6 +184,7 @@ class SparseSemiStructuredTensor(torch.Tensor): outer_stride, ) -> torch.Tensor: shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta + # pyrefly: ignore # no-matching-overload return cls( shape=shape, packed=inner_tensors.get("packed", None), @@ -413,6 +414,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): sparse_tensor_cutlass, meta_tensor_cutlass, ) = sparse_semi_structured_from_dense_cutlass(original_tensor) + # pyrefly: ignore # no-matching-overload return cls( original_tensor.shape, packed=sparse_tensor_cutlass, @@ -499,6 +501,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): original_tensor, algorithm=algorithm, use_cutlass=True ) + # pyrefly: ignore # no-matching-overload return cls( original_tensor.shape, packed=packed, @@ -560,6 +563,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): cls, original_tensor: torch.Tensor ) -> "SparseSemiStructuredTensorCUSPARSELT": cls._validate_device_dim_dtype_shape(original_tensor) + # pyrefly: ignore # no-matching-overload return cls( shape=original_tensor.shape, packed=torch._cslt_compress(original_tensor), @@ -626,6 +630,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): packed = packed.view(original_tensor.shape[0], -1) packed_t = packed_t.view(original_tensor.shape[1], -1) + # pyrefly: ignore # no-matching-overload return cls( original_tensor.shape, packed=packed, diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 8da9a0bef6b2..d7f65dd0c16e 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1336,6 +1336,7 @@ class Identity(sympy.Function): def _sympystr(self, printer): """Controls how sympy's StrPrinter prints this""" + # pyrefly: ignore # missing-attribute return f"({printer.doprint(self.args[0])})" def _eval_is_real(self):
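For reference (illustrative, not part of the patch): a minimal sketch of the data flow that tools/linter/adapters/pyrefly_linter.py implements, reconstructed only from the fields the adapter reads from pyrefly's `--output-format=json` output and the JSON lines it prints for lintrunner. The file path, positions, and error payload below are hypothetical sample values.

    import json

    # Hypothetical pyrefly JSON payload, shaped the way check_files() reads it
    # (error["path"], ["line"], ["column"], ["name"], and the description fields).
    sample_stdout = json.dumps(
        {
            "errors": [
                {
                    "path": "torch/example.py",  # hypothetical file
                    "line": 12,
                    "column": 5,
                    "name": "bad-argument-type",
                    "description": "Expected `int`, got `str`",
                }
            ]
        }
    )

    # The adapter emits one JSON object per finding on stdout; lintrunner consumes these lines.
    for error in json.loads(sample_stdout).get("errors", []):
        print(
            json.dumps(
                {
                    "path": error["path"],
                    "line": error["line"],
                    "char": error["column"],
                    "code": "PYREFLY",
                    "severity": "advice",  # downgraded to ADVICE while mypy remains authoritative
                    "name": error["name"],
                    "original": None,
                    "replacement": None,
                    "description": error.get("description", ""),
                }
            ),
            flush=True,
        )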