[RFC] Add pyrefly to lintrunner (#165179)

This adds pyrefly to lintrunner as a warning-only linter, which will let us collect feedback about the tool before switching to pyrefly as the main type checker.

References the steps outlined here: https://github.com/pytorch/pytorch/issues/163283

Test plan:
`lintrunner init`
`lintrunner`
Confirm that when pyrefly errors are present, the results look like: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f
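A possible way to exercise only the new linter (illustrative, not part of the original test plan; `--take` is lintrunner's counterpart to the `--skip` flag used in CI below):
`lintrunner --take PYREFLY --all-files`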

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179
Approved by: https://github.com/ezyang
Commit: d795fb225a (parent 7df9aca529)
Author: Maggie Moss, 2025-10-16 20:07:05 +00:00
Committed by: PyTorch MergeBot
30 changed files with 357 additions and 17 deletions

View File

@@ -118,9 +118,9 @@ jobs:
           CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
           echo "Running all other linters"
           if [ "$CHANGED_FILES" = '*' ]; then
-            ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
+            ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
           else
-            ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh
+            ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
           fi

   quick-checks:

View File

@@ -209,6 +209,46 @@ command = [
     '@{{PATHSFILE}}'
 ]

+[[linter]]
+code = 'PYREFLY'
+include_patterns = [
+    'torch/**/*.py',
+    'torch/**/*.pyi',
+    'torchgen/**/*.py',
+    'torchgen/**/*.pyi',
+    'functorch/**/*.py',
+    'functorch/**/*.pyi',
+]
+exclude_patterns = []
+command = [
+    'python3',
+    'tools/linter/adapters/pyrefly_linter.py',
+    '--config=pyrefly.toml',
+]
+init_command = [
+    'python3',
+    'tools/linter/adapters/pip_init.py',
+    '--dry-run={{DRYRUN}}',
+    'numpy==2.1.0 ; python_version >= "3.12"',
+    'expecttest==0.3.0',
+    'pyrefly==0.36.2',
+    'sympy==1.13.3',
+    'types-requests==2.27.25',
+    'types-pyyaml==6.0.2',
+    'types-tabulate==0.8.8',
+    'types-protobuf==5.29.1.20250403',
+    'types-setuptools==79.0.0.20250422',
+    'types-jinja2==2.11.9',
+    'types-colorama==0.4.6',
+    'filelock==3.18.0',
+    'junitparser==2.1.1',
+    'rich==14.1.0',
+    'optree==0.17.0',
+    'types-openpyxl==3.1.5.20250919',
+    'types-python-dateutil==2.9.0.20251008'
+]
+
 [[linter]]
 code = 'CLANGTIDY'
 include_patterns = [
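Note that, unlike the linter entry above, the PYREFLY command takes no `@{{PATHSFILE}}` argument: rather than checking the file list lintrunner hands it, the adapter runs a single whole-project `pyrefly check` against `pyrefly.toml` (see the adapter source below). The equivalent manual invocation would simply be:
`python3 tools/linter/adapters/pyrefly_linter.py --config=pyrefly.toml`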

View File

@ -1,5 +1,7 @@
# A Pyrefly configuration for PyTorch # A Pyrefly configuration for PyTorch
# Based on https://github.com/pytorch/pytorch/blob/main/mypy.ini # Based on https://github.com/pytorch/pytorch/blob/main/mypy.ini
python-version = "3.12"
project-includes = [ project-includes = [
"torch", "torch",
"caffe2", "caffe2",
@ -36,6 +38,7 @@ project-excludes = [
"torch/nn/modules/rnn.py", # only remove when parsing errors are fixed "torch/nn/modules/rnn.py", # only remove when parsing errors are fixed
"torch/_inductor/codecache.py", "torch/_inductor/codecache.py",
"torch/distributed/elastic/metrics/__init__.py", "torch/distributed/elastic/metrics/__init__.py",
"torch/_inductor/fx_passes/bucketing.py",
# ==== # ====
"benchmarks/instruction_counts/main.py", "benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py", "benchmarks/instruction_counts/definitions/setup.py",

View File

@@ -0,0 +1,258 @@
from __future__ import annotations

import argparse
import json
import logging
import os
import re
import subprocess
import sys
import time
from enum import Enum
from typing import NamedTuple


class LintSeverity(str, Enum):
    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    path: str | None
    line: int | None
    char: int | None
    code: str
    severity: LintSeverity
    name: str
    original: str | None
    replacement: str | None
    description: str | None


# Note: This regex pattern is kept for reference but not used for pyrefly JSON parsing
RESULTS_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    (?:(?P<column>-?\d+):)?
    \s(?P<severity>\S+?):?
    \s(?P<message>.*)
    \s(?P<code>\[.*\])
    $
    """
)

# torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    \s(?P<severity>\S+?):?
    \s(?P<message>INTERNAL\sERROR.*)
    $
    """
)


def run_command(
    args: list[str],
    *,
    extra_env: dict[str, str] | None,
    retries: int,
) -> subprocess.CompletedProcess[bytes]:
    # Note: extra_env and retries are accepted for parity with other
    # lintrunner adapters but are not currently used here.
    logging.debug("$ %s", " ".join(args))
    start_time = time.monotonic()
    try:
        return subprocess.run(
            args,
            capture_output=True,
        )
    finally:
        end_time = time.monotonic()
        logging.debug("took %dms", (end_time - start_time) * 1000)


# Severity mapping (currently only used for stderr internal errors).
# Pyrefly JSON output doesn't include severity, so all errors default to ERROR.
severities = {
    "error": LintSeverity.ERROR,
    "note": LintSeverity.ADVICE,
}


def check_pyrefly_installed(code: str) -> list[LintMessage]:
    cmd = ["pyrefly", "--version"]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
        return []
    except subprocess.CalledProcessError as e:
        msg = e.stderr.decode(errors="replace")
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=f"Could not run '{' '.join(cmd)}': {msg}",
            )
        ]


def in_github_actions() -> bool:
    return bool(os.getenv("GITHUB_ACTIONS"))


def check_files(
    code: str,
    config: str,
) -> list[LintMessage]:
    try:
        pyrefly_commands = [
            "pyrefly",
            "check",
            "--config",
            config,
            "--output-format=json",
        ]
        proc = run_command(
            [*pyrefly_commands],
            extra_env={},
            retries=0,
        )
    except OSError as err:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
            )
        ]
    stdout = str(proc.stdout, "utf-8").strip()
    stderr = str(proc.stderr, "utf-8").strip()
    if proc.returncode not in (0, 1):
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=stderr,
            )
        ]
    # Parse JSON output from pyrefly
    try:
        if stdout:
            result = json.loads(stdout)
            errors = result.get("errors", [])
        else:
            errors = []
        # For now, filter out deprecation warnings and report the remaining
        # type errors as advice until we remove mypy.
        errors = [error for error in errors if error["name"] != "deprecated"]
        rc = [
            LintMessage(
                path=error["path"],
                name=error["name"],
                description=error.get(
                    "description", error.get("concise_description", "")
                ),
                line=error["line"],
                char=error["column"],
                code=code,
                severity=LintSeverity.ADVICE,
                # uncomment and replace when we switch to pyrefly
                # severity=LintSeverity.ADVICE if error["name"] == "deprecated" else LintSeverity.ERROR,
                original=None,
                replacement=None,
            )
            for error in errors
        ]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="json-parse-error",
                original=None,
                replacement=None,
                description=f"Failed to parse pyrefly JSON output: {e}",
            )
        ]
    # Still check stderr for internal errors
    rc += [
        LintMessage(
            path=match["file"],
            name="INTERNAL ERROR",
            description=match["message"],
            line=int(match["line"]),
            char=None,
            code=code,
            severity=severities.get(match["severity"], LintSeverity.ERROR),
            original=None,
            replacement=None,
        )
        for match in INTERNAL_ERROR_RE.finditer(stderr)
    ]
    return rc


def main() -> None:
    parser = argparse.ArgumentParser(
        description="pyrefly wrapper linter.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--code",
        default="PYREFLY",
        help="the code this lint should report as",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    parser.add_argument(
        "--config",
        required=True,
        help="path to a pyrefly config file",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=logging.INFO,
        stream=sys.stderr,
    )

    lint_messages = check_pyrefly_installed(args.code) + check_files(
        args.code, args.config
    )
    for lint_message in lint_messages:
        print(json.dumps(lint_message._asdict()), flush=True)


if __name__ == "__main__":
    main()
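
As main() shows, the adapter speaks lintrunner's protocol by printing one JSON-encoded LintMessage per line to stdout. A minimal sketch of what a single finding looks like on the wire, reusing the definitions above (the path, line, and message are invented for illustration):

# Hypothetical finding; field names follow the LintMessage NamedTuple above.
example = LintMessage(
    path="torch/example.py",
    line=42,
    char=5,
    code="PYREFLY",
    severity=LintSeverity.ADVICE,
    name="bad-assignment",
    original=None,
    replacement=None,
    description="`str` is not assignable to `int`",
)
print(json.dumps(example._asdict()))
# -> {"path": "torch/example.py", "line": 42, "char": 5, "code": "PYREFLY",
#     "severity": "advice", "name": "bad-assignment", "original": null,
#     "replacement": null, "description": "`str` is not assignable to `int`"}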

View File

@@ -76,6 +76,7 @@ class StreamVariable(VariableTracker):
         super().__init__(**kwargs)
         self.proxy = proxy
         self.value = value
+        # pyrefly: ignore # read-only
         self.device = device

     def python_type(self) -> type:
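
This is the first of many hunks that add pyrefly suppressions, so the convention is worth spelling out: a `# pyrefly: ignore # <error-code>` comment on its own line silences that specific error code on the line that follows (placed at the end of a line, as in some import hunks below, it applies to that line). A minimal invented example:

x: int = 0
# pyrefly: ignore # bad-assignment
x = "not an int"  # flagged by pyrefly; the comment above suppresses only this code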

View File

@@ -1492,6 +1492,7 @@ def _aot_stage2a_partition(
     # apply joint_gm callback here
     if callable(torch._functorch.config.joint_custom_pass):
+        # pyrefly: ignore # bad-assignment
        fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)

     static_lifetime_input_indices = fw_metadata.static_input_indices

@@ -1761,6 +1762,7 @@ def _aot_stage2b_bw_compile(
             # tensor which is wrong.
             ph_size = ph_arg.size()

+            # pyrefly: ignore # bad-argument-type
             if len(ph_size) == 0 and len(real_stride) > 0:
                 # Fix for 0-dimensional tensors: When a tensor becomes 0-d
                 # (e.g., via squeeze), its stride should be () not (1,).

View File

@@ -628,6 +628,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
         position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
     ]
     # add the scale nodes to the output; find the first sym_node in the output
+    # pyrefly: ignore # bad-argument-type
     idx = find_first_sym_node(output_updated_args)
     scale_nodes = tensor_scale_nodes + sym_scale_nodes
     if scale_nodes:

View File

@@ -86,7 +86,7 @@ def bucket_all_gather(
     mode: BucketMode = "default",
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
-        from torch._inductor.fx_passes.bucketing import (
+        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore # missing-module-attribute
            bucket_cap_mb_by_bucket_idx_default,
        )

@@ -103,7 +103,7 @@ def bucket_reduce_scatter(
     mode: BucketMode = "default",
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
-        from torch._inductor.fx_passes.bucketing import (
+        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore # missing-module-attribute
            bucket_cap_mb_by_bucket_idx_default,
        )

View File

@@ -209,6 +209,7 @@ def addmm_patterns_init():
         # pyrefly: ignore # bad-argument-type
         int8_woq_fusion_replacement,
         [val(), val(), val(), val(), scale(), scale(), scale()],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_patterns[0],

@@ -230,6 +231,7 @@ def addmm_patterns_init():
         # pyrefly: ignore # bad-argument-type
         matmul_replacement,
         [val(), val(), val(), val()],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_patterns[0],

@@ -251,6 +253,7 @@ def addmm_patterns_init():
         # pyrefly: ignore # bad-argument-type
         matmul_replacement_two,
         [val(), val(), val()],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_patterns[0],

@@ -276,6 +279,7 @@ def addmm_patterns_init():
         # pyrefly: ignore # bad-argument-type
         addmm_fuse_replacement_second,
         [val() for _ in range(7)],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_patterns[0],

View File

@@ -49,6 +49,7 @@ def _misc_patterns_init():
         # pyrefly: ignore # bad-argument-type
         randperm_index_add_replacement,
         [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         [post_grad_patterns, joint_graph_patterns],

@@ -68,6 +69,7 @@ def _misc_patterns_init():
         # pyrefly: ignore # bad-argument-type
         randperm_index_replacement,
         [torch.empty(4, 8, device=device)],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         [post_grad_patterns, joint_graph_patterns],

View File

@@ -919,6 +919,7 @@ def _pad_mm_init() -> None:
             pattern,
             replacement,
             args,
+            # pyrefly: ignore # bad-argument-type
             joint_fwd_bwd,
             # pyrefly: ignore # bad-argument-type
             patterns,

@@ -931,6 +932,7 @@ def _pad_mm_init() -> None:
             pattern,
             replacement,
             args,
+            # pyrefly: ignore # bad-argument-type
             fwd_only,
             # pyrefly: ignore # bad-argument-type
             patterns,

View File

@@ -666,6 +666,7 @@ def lazy_init():
         prepare_softmax_replacement,
         [torch.empty(4, 8)],
         scalar_workaround=dict(dim=-1),
+        # pyrefly: ignore # bad-argument-type
         trace_fn=fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_dicts=pass_patterns[1],

@@ -730,6 +731,7 @@ def register_lowering_pattern(
     return pattern_matcher.register_lowering_pattern(
         pattern,
         extra_check,
+        # pyrefly: ignore # bad-argument-type
         pass_dict=pass_patterns[pass_number],
     )

@@ -1573,6 +1575,7 @@ def register_partial_reduction_pattern():
     @register_graph_pattern(
         MultiOutputPattern([partial_reduc, full_reduc]),
+        # pyrefly: ignore # bad-argument-type
         pass_dict=pass_patterns[2],
     )
     def reuse_partial(match, input, reduced_dims, keepdim):

View File

@@ -27,7 +27,7 @@ from torch._dynamo.utils import counters
 from torch._higher_order_ops.associative_scan import associative_scan_op
 from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
 from torch._library.utils import get_layout_constraint_tag
-from torch._prims_common import (
+from torch._prims_common import (  # pyrefly: ignore # deprecated
     canonicalize_dim,
     canonicalize_dims,
     check,

View File

@@ -173,6 +173,7 @@ def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
             f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
         )
     _OPAQUE_TYPES[cls] = name
+    # pyrefly: ignore # missing-attribute
     torch._C._register_opaque_type(name)

@@ -182,4 +183,5 @@ def is_opaque_type(cls: Any) -> bool:
     """
     if cls not in _OPAQUE_TYPES:
         return False
+    # pyrefly: ignore # missing-attribute
     return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])

View File

@@ -135,7 +135,7 @@ if is_available():
     # this.
     # pyrefly: ignore # deprecated
     from .distributed_c10d import *  # noqa: F403
-    from .distributed_c10d import (
+    from .distributed_c10d import (  # pyrefly: ignore # deprecated
         _all_gather_base,
         _coalescing_manager,
         _CoalescingManager,

View File

@@ -1009,8 +1009,8 @@ lib_impl.impl("broadcast", _broadcast_meta, "Meta")
 lib_impl.impl("broadcast_", _broadcast__meta, "Meta")

 # mark these ops has side effect so that they won't be removed by DCE
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)
+torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)  # type: ignore[has-type]
+torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)  # type: ignore[has-type]

 # Register legacy ops for backward compatibility
 # TODO(yifu): remove these in functional collective beta release

@@ -1176,7 +1176,7 @@ def all_gather_inplace(
     return tensor_list

-from torch.distributed.distributed_c10d import (
+from torch.distributed.distributed_c10d import (  # pyrefly: ignore # deprecated
     _all_gather_base as legacy_all_gather_base,
     _reduce_scatter_base as legacy_reduce_scatter_base,
     all_gather as legacy_all_gather,

@@ -1190,11 +1190,11 @@ from torch.distributed.distributed_c10d import (
 # This dict should contain sets of functions that dynamo is allowed to remap.
 # Functions in this set should accept the same args/kwargs 1:1 as their mapping.
 traceable_collective_remaps = {
-    legacy_allgather: all_gather_tensor_inplace,
-    legacy_reducescatter: reduce_scatter_tensor_inplace,
-    legacy_allreduce: all_reduce_inplace,
-    legacy_all_to_all_single: all_to_all_inplace,
-    legacy_all_gather: all_gather_inplace,
-    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
-    legacy_all_gather_base: all_gather_tensor_inplace,
+    legacy_allgather: all_gather_tensor_inplace,  # type: ignore[has-type]
+    legacy_reducescatter: reduce_scatter_tensor_inplace,  # type: ignore[has-type]
+    legacy_allreduce: all_reduce_inplace,  # type: ignore[has-type]
+    legacy_all_to_all_single: all_to_all_inplace,  # type: ignore[has-type]
+    legacy_all_gather: all_gather_inplace,  # type: ignore[has-type]
+    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,  # type: ignore[has-type]
+    legacy_all_gather_base: all_gather_tensor_inplace,  # type: ignore[has-type]
 }
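
Two suppression systems coexist in hunks like this while both checkers run: trailing `# type: ignore[...]` comments are read by mypy, and `# pyrefly: ignore # <code>` comments are read by pyrefly; each tool ignores the other's markers. An invented example (runnable, since Python does not enforce annotations at runtime):

def legacy_fn():
    return 1

# The trailing comment below is read by mypy only:
x = legacy_fn()  # type: ignore[has-type]
# The comment below is read by pyrefly only:
# pyrefly: ignore # bad-assignment
y: int = "not an int"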

View File

@@ -393,6 +393,7 @@ class LocalTensor(torch.Tensor):
     def __repr__(self) -> str:  # type: ignore[override]
         parts = []
         for k, v in self._local_tensors.items():
+            # pyrefly: ignore # bad-argument-type
             parts.append(f" {k}: {v}")
         tensors_str = ",\n".join(parts)
         return f"LocalTensor(\n{tensors_str}\n)"

@@ -680,6 +681,7 @@ class LocalTensorMode(TorchDispatchMode):
     def _unpatch_device_mesh(self) -> None:
         assert self._old_get_coordinate is not None
         DeviceMesh.get_coordinate = self._old_get_coordinate
+        # pyrefly: ignore # bad-assignment
         self._old_get_coordinate = None

View File

@@ -316,6 +316,7 @@ def _local_all_gather_(
     assert len(input_tensors) == 1
     input_tensor = input_tensors[0]
+    # pyrefly: ignore # bad-assignment
     output_tensors = output_tensors[0]

     ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)

@@ -336,10 +337,12 @@ def _local_all_gather_(
             source_tensor = input_tensor
             if isinstance(input_tensor, LocalTensor):
                 source_tensor = input_tensor._local_tensors[rank_i]
+            # pyrefly: ignore # missing-attribute
             output_tensors[i].copy_(source_tensor)

     work = FakeWork()
     work_so = Work.boxed(work)
+    # pyrefly: ignore # bad-return
     return ([output_tensors], work_so)

@@ -426,6 +429,7 @@ def _local_scatter_(
     assert len(output_tensors) == 1
     assert len(input_tensors) == 1
     output_tensor = output_tensors[0]
+    # pyrefly: ignore # bad-assignment
     input_tensors = input_tensors[0]

     ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)

View File

@@ -90,6 +90,7 @@ class DTensorSpec:
         if not isinstance(self.placements, tuple):
             self.placements = tuple(self.placements)
         if self.shard_order is None:
+            # pyrefly: ignore # bad-assignment
             self.shard_order = DTensorSpec.compute_default_shard_order(self.placements)

         self._hash: int | None = None

View File

@@ -701,6 +701,7 @@ def _restore_state_dict(
     for name, _ in list(
         chain(
             original_module.named_parameters(remove_duplicate=False),
+            # pyrefly: ignore # bad-argument-type
             original_module.named_buffers(remove_duplicate=False),
         )
     ):

View File

@@ -218,6 +218,7 @@ class FlexKernelOptions(TypedDict, total=False):
     waves_per_eu: NotRequired[int]
     """ROCm-specific waves per execution unit."""

+    # pyrefly: ignore # invalid-annotation
     force_flash: NotRequired[bool]
     """ If True, forces use of the cute-dsl flash attention kernel.

View File

@@ -1,5 +1,5 @@
 from . import parametrizations, parametrize, rnn, stateless
-from .clip_grad import (
+from .clip_grad import (  # pyrefly: ignore # deprecated
     _clip_grads_with_norm_ as clip_grads_with_norm_,
     _get_total_norm as get_total_norm,
     clip_grad_norm,

View File

@@ -283,6 +283,7 @@ def clip_grad_value_(
     clip_value = float(clip_value)

     grads = [p.grad for p in parameters if p.grad is not None]
+    # pyrefly: ignore # bad-argument-type
     grouped_grads = _group_tensors_by_device_and_dtype([grads])

     for (device, _), ([grads], _) in grouped_grads.items():

View File

@@ -111,8 +111,10 @@ class _Orthogonal(Module):
             Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

         if hasattr(self, "base"):
+            # pyrefly: ignore # unbound-name
             Q = self.base @ Q
         if transposed:
+            # pyrefly: ignore # unbound-name
             Q = Q.mT
         return Q  # type: ignore[possibly-undefined]

View File

@@ -170,6 +170,7 @@ class TorchTensor(ir.Tensor):
         if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
             raise TypeError(
+                # pyrefly: ignore # missing-attribute
                 f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
                 "with a tensor backed by real data using ONNXProgram.apply_weights() "
                 "or save the model without initializers by setting include_initializers=False."

View File

@@ -297,6 +297,7 @@ class AveragedModel(Module):
                 avg_fn = get_swa_avg_fn()
                 n_averaged = self.n_averaged.to(device)
                 for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
+                    # pyrefly: ignore # missing-attribute
                     p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
             else:
                 for p_averaged, p_model in zip(  # type: ignore[assignment]

View File

@@ -71,6 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
             nrows // 16, 16
         )
     ).view(-1)
+    # pyrefly: ignore # unbound-name
     outp = outp.index_copy(1, cols_permuted, outp)

     # interleave_column_major_tensor

View File

@@ -67,6 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
     # Because we cannot go from the compressed representation back to the dense representation currently,
     # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
     # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
+    # pyrefly: ignore # no-matching-overload
     return self.__class__(
         torch.Size([self.shape[-1], self.shape[0]]),
         packed=self.packed_t,

View File

@@ -184,6 +184,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         outer_stride,
     ) -> torch.Tensor:
         shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
+        # pyrefly: ignore # no-matching-overload
         return cls(
             shape=shape,
             packed=inner_tensors.get("packed", None),

@@ -413,6 +414,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
             sparse_tensor_cutlass,
             meta_tensor_cutlass,
         ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
+        # pyrefly: ignore # no-matching-overload
         return cls(
             original_tensor.shape,
             packed=sparse_tensor_cutlass,

@@ -499,6 +501,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
             original_tensor, algorithm=algorithm, use_cutlass=True
         )

+        # pyrefly: ignore # no-matching-overload
         return cls(
             original_tensor.shape,
             packed=packed,

@@ -560,6 +563,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         cls, original_tensor: torch.Tensor
     ) -> "SparseSemiStructuredTensorCUSPARSELT":
         cls._validate_device_dim_dtype_shape(original_tensor)
+        # pyrefly: ignore # no-matching-overload
         return cls(
             shape=original_tensor.shape,
             packed=torch._cslt_compress(original_tensor),

@@ -626,6 +630,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
             packed = packed.view(original_tensor.shape[0], -1)
             packed_t = packed_t.view(original_tensor.shape[1], -1)

+        # pyrefly: ignore # no-matching-overload
         return cls(
             original_tensor.shape,
             packed=packed,

View File

@@ -1336,6 +1336,7 @@ class Identity(sympy.Function):
     def _sympystr(self, printer):
         """Controls how sympy's StrPrinter prints this"""
+        # pyrefly: ignore # missing-attribute
         return f"({printer.doprint(self.args[0])})"

     def _eval_is_real(self):