Compare commits

...

11 Commits

26 changed files with 327 additions and 6 deletions

View File

@@ -209,6 +209,46 @@ command = [
'@{{PATHSFILE}}'
]
[[linter]]
code = 'PYREFLY'
include_patterns = [
'torch/**/*.py',
'torch/**/*.pyi',
'torchgen/**/*.py',
'torchgen/**/*.pyi',
'functorch/**/*.py',
'functorch/**/*.pyi',
]
exclude_patterns = []
command = [
'python3',
'tools/linter/adapters/pyrefly_linter.py',
]
init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'pyrefly==0.36.2',
'sympy==1.13.3',
'types-requests==2.27.25',
'types-pyyaml==6.0.2',
'types-tabulate==0.8.8',
'types-protobuf==5.29.1.20250403',
'types-setuptools==79.0.0.20250422',
'types-jinja2==2.11.9',
'types-colorama==0.4.6',
'filelock==3.18.0',
'junitparser==2.1.1',
'rich==14.1.0',
'optree==0.17.0',
'types-openpyxl==3.1.5.20250919',
'types-python-dateutil==2.9.0.20251008'
]
[[linter]]
code = 'CLANGTIDY'
include_patterns = [

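Note: the PYREFLY entry above wires a new adapter (added later in this compare as tools/linter/adapters/pyrefly_linter.py) into lintrunner; the adapter prints one JSON-encoded LintMessage per line on stdout. A minimal sketch, not part of this compare, of driving the adapter by hand and consuming that output — it assumes you run from the repository root with pyrefly and the pinned packages from init_command installed:

# Sketch only: invoke the new adapter directly and parse its
# JSON-lines output (one LintMessage dict per line).
import json
import subprocess

proc = subprocess.run(
    ["python3", "tools/linter/adapters/pyrefly_linter.py", "--code", "PYREFLY"],
    capture_output=True,
    text=True,
)
for raw in proc.stdout.splitlines():
    msg = json.loads(raw)
    print(msg["severity"], msg["path"], msg["name"], msg["description"])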
View File

@@ -36,6 +36,7 @@ project-excludes = [
"torch/nn/modules/rnn.py", # only remove when parsing errors are fixed
"torch/_inductor/codecache.py",
"torch/distributed/elastic/metrics/__init__.py",
"torch/_inductor/fx_passes/bucketing.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",

View File

@@ -0,0 +1,244 @@
from __future__ import annotations

import argparse
import json
import logging
import os
import re
import subprocess
import sys
import time
from enum import Enum
from typing import NamedTuple


class LintSeverity(str, Enum):
    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    path: str | None
    line: int | None
    char: int | None
    code: str
    severity: LintSeverity
    name: str
    original: str | None
    replacement: str | None
    description: str | None


# Note: This regex pattern is kept for reference but not used for pyrefly JSON parsing
RESULTS_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    (?:(?P<column>-?\d+):)?
    \s(?P<severity>\S+?):?
    \s(?P<message>.*)
    \s(?P<code>\[.*\])
    $
    """
)

# torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    \s(?P<severity>\S+?):?
    \s(?P<message>INTERNAL\sERROR.*)
    $
    """
)


def run_command(
    args: list[str],
    *,
    extra_env: dict[str, str] | None,
    retries: int,
) -> subprocess.CompletedProcess[bytes]:
    # extra_env and retries are accepted for parity with the other linter
    # adapters but are currently unused here.
    logging.debug("$ %s", " ".join(args))
    start_time = time.monotonic()
    try:
        return subprocess.run(
            args,
            capture_output=True,
        )
    finally:
        end_time = time.monotonic()
        logging.debug("took %dms", (end_time - start_time) * 1000)


# Severity mapping (currently only used for stderr internal errors).
# Pyrefly JSON output doesn't include severity, so all errors default to ERROR.
severities = {
    "error": LintSeverity.ERROR,
    "note": LintSeverity.ADVICE,
}


def check_pyrefly_installed(code: str) -> list[LintMessage]:
    cmd = ["pyrefly", "--version"]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
        return []
    except subprocess.CalledProcessError as e:
        msg = e.stderr.decode(errors="replace")
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=f"Could not run '{' '.join(cmd)}': {msg}",
            )
        ]


def in_github_actions() -> bool:
    return bool(os.getenv("GITHUB_ACTIONS"))


def check_files(
    code: str,
) -> list[LintMessage]:
    try:
        pyrefly_commands = ["pyrefly", "check", "--output-format=json"]
        proc = run_command(
            pyrefly_commands,
            extra_env={},
            retries=0,
        )
    except OSError as err:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
            )
        ]

    stdout = str(proc.stdout, "utf-8").strip()
    stderr = str(proc.stderr, "utf-8").strip()
    if proc.returncode not in (0, 1):
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=stderr,
            )
        ]

    # Parse JSON output from pyrefly
    try:
        if stdout:
            result = json.loads(stdout)
            errors = result.get("errors", [])
        else:
            errors = []
        # For now, filter out "deprecated" diagnostics and downgrade everything
        # else to advice (rather than hard errors) until mypy is removed.
        errors = [error for error in errors if error["name"] != "deprecated"]
        rc = [
            LintMessage(
                path=error["path"],
                name=error["name"],
                description=error.get(
                    "description", error.get("concise_description", "")
                ),
                line=error["line"],
                char=error["column"],
                code=code,
                severity=LintSeverity.ADVICE,
                # uncomment and replace when we switch to pyrefly
                # severity=LintSeverity.ADVICE if error["name"] == "deprecated" else LintSeverity.ERROR,
                original=None,
                replacement=None,
            )
            for error in errors
        ]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="json-parse-error",
                original=None,
                replacement=None,
                description=f"Failed to parse pyrefly JSON output: {e}",
            )
        ]

    # Still check stderr for internal errors
    rc += [
        LintMessage(
            path=match["file"],
            name="INTERNAL ERROR",
            description=match["message"],
            line=int(match["line"]),
            char=None,
            code=code,
            severity=severities.get(match["severity"], LintSeverity.ERROR),
            original=None,
            replacement=None,
        )
        for match in INTERNAL_ERROR_RE.finditer(stderr)
    ]
    return rc


def main() -> None:
    parser = argparse.ArgumentParser(
        description="pyrefly wrapper linter.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--code",
        default="PYREFLY",
        help="the code this lint should report as",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=logging.NOTSET if args.verbose else logging.INFO,
        stream=sys.stderr,
    )

    lint_messages = check_pyrefly_installed(args.code) + check_files(args.code)
    for lint_message in lint_messages:
        print(json.dumps(lint_message._asdict()), flush=True)


if __name__ == "__main__":
    main()
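For reference, a hedged sketch of the pyrefly JSON shape check_files() relies on, inferred from the parsing code above; the field values are illustrative and pyrefly may emit additional keys:

# Illustrative only: the minimal structure the adapter expects from
# `pyrefly check --output-format=json`.
example = {
    "errors": [
        {
            "path": "torch/example.py",  # hypothetical file
            "line": 42,
            "column": 9,
            "name": "bad-assignment",
            "description": "`str` is not assignable to `int`",
            # used as a fallback when "description" is absent:
            "concise_description": "bad assignment",
        }
    ]
}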

View File

@@ -1559,6 +1559,7 @@ def _aot_stage2a_partition(
# apply joint_gm callback here
if callable(torch._functorch.config.joint_custom_pass):
# pyrefly: ignore # bad-assignment
fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)
static_lifetime_input_indices = fw_metadata.static_input_indices
@@ -1901,6 +1902,7 @@ def _aot_stage2b_bw_compile(
# tensor which is wrong.
ph_size = ph_arg.size()
# pyrefly: ignore # bad-argument-type
if len(ph_size) == 0 and len(real_stride) > 0:
# Fix for 0-dimensional tensors: When a tensor becomes 0-d
# (e.g., via squeeze), its stride should be () not (1,).
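The remaining files in this compare follow the pattern shown above: targeted pyrefly suppression comments. As a minimal illustration (not taken from the PR), a `# pyrefly: ignore # <error-code>` comment silences that specific diagnostic, placed either on the line preceding the flagged statement or trailing it:

# pyrefly: ignore # bad-assignment
x: int = "not an int"  # deliberately mistyped; the comment above suppresses it


def takes_int(n: int) -> int:
    return n


takes_int("3")  # pyrefly: ignore # bad-argument-type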

View File

@@ -53,7 +53,7 @@ def bucket_all_gather(
mode: Optional[str] = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
bucket_cap_mb_by_bucket_idx_default,
)
@@ -70,7 +70,7 @@ def bucket_reduce_scatter(
mode: Optional[str] = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
from torch._inductor.fx_passes.bucketing import ( # pyrefly: ignore # missing-module-attribute
bucket_cap_mb_by_bucket_idx_default,
)

View File

@@ -209,6 +209,7 @@ def addmm_patterns_init():
# pyrefly: ignore # bad-argument-type
int8_woq_fusion_replacement,
[val(), val(), val(), val(), scale(), scale(), scale()],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
pass_patterns[0],
@@ -230,6 +231,7 @@ def addmm_patterns_init():
# pyrefly: ignore # bad-argument-type
matmul_replacement,
[val(), val(), val(), val()],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
pass_patterns[0],
@@ -251,6 +253,7 @@ def addmm_patterns_init():
# pyrefly: ignore # bad-argument-type
matmul_replacement_two,
[val(), val(), val()],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
pass_patterns[0],
@@ -276,6 +279,7 @@ def addmm_patterns_init():
# pyrefly: ignore # bad-argument-type
addmm_fuse_replacement_second,
[val() for _ in range(7)],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
pass_patterns[0],

View File

@@ -49,6 +49,7 @@ def _misc_patterns_init():
# pyrefly: ignore # bad-argument-type
randperm_index_add_replacement,
[torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
[post_grad_patterns, joint_graph_patterns],
@@ -68,6 +69,7 @@ def _misc_patterns_init():
# pyrefly: ignore # bad-argument-type
randperm_index_replacement,
[torch.empty(4, 8, device=device)],
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
[post_grad_patterns, joint_graph_patterns],

View File

@@ -921,6 +921,7 @@ def _pad_mm_init() -> None:
pattern,
replacement,
args,
# pyrefly: ignore # bad-argument-type
joint_fwd_bwd,
# pyrefly: ignore # bad-argument-type
patterns,
@@ -933,6 +934,7 @@ def _pad_mm_init() -> None:
pattern,
replacement,
args,
# pyrefly: ignore # bad-argument-type
fwd_only,
# pyrefly: ignore # bad-argument-type
patterns,

View File

@@ -666,6 +666,7 @@ def lazy_init():
prepare_softmax_replacement,
[torch.empty(4, 8)],
scalar_workaround=dict(dim=-1),
# pyrefly: ignore # bad-argument-type
trace_fn=fwd_only,
# pyrefly: ignore # bad-argument-type
pass_dicts=pass_patterns[1],
@@ -730,6 +731,7 @@ def register_lowering_pattern(
return pattern_matcher.register_lowering_pattern(
pattern,
extra_check,
# pyrefly: ignore # bad-argument-type
pass_dict=pass_patterns[pass_number],
)
@@ -1573,6 +1575,7 @@ def register_partial_reduction_pattern():
@register_graph_pattern(
MultiOutputPattern([partial_reduc, full_reduc]),
# pyrefly: ignore # bad-argument-type
pass_dict=pass_patterns[2],
)
def reuse_partial(match, input, reduced_dims, keepdim):

View File

@@ -27,7 +27,7 @@ from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import associative_scan_op
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._library.utils import get_layout_constraint_tag
from torch._prims_common import (
from torch._prims_common import ( # pyrefly: ignore # deprecated
canonicalize_dim,
canonicalize_dims,
check,

View File

@@ -173,6 +173,7 @@ def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
)
_OPAQUE_TYPES[cls] = name
# pyrefly: ignore # missing-attribute
torch._C._register_opaque_type(name)
@@ -182,4 +183,5 @@ def is_opaque_type(cls: Any) -> bool:
"""
if cls not in _OPAQUE_TYPES:
return False
# pyrefly: ignore # missing-attribute
return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])

View File

@@ -135,7 +135,7 @@ if is_available():
# this.
# pyrefly: ignore # deprecated
from .distributed_c10d import * # noqa: F403
from .distributed_c10d import (
from .distributed_c10d import ( # pyrefly: ignore # deprecated
_all_gather_base,
_coalescing_manager,
_CoalescingManager,

View File

@@ -1176,7 +1176,7 @@ def all_gather_inplace(
return tensor_list
from torch.distributed.distributed_c10d import (
from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated
_all_gather_base as legacy_all_gather_base,
_reduce_scatter_base as legacy_reduce_scatter_base,
all_gather as legacy_all_gather,

View File

@@ -375,6 +375,7 @@ class LocalTensor(torch.Tensor):
def __repr__(self) -> str: # type: ignore[override]
parts = []
for k, v in self._local_tensors.items():
# pyrefly: ignore # bad-argument-type
parts.append(f" {k}: {v}")
tensors_str = ",\n".join(parts)
return f"LocalTensor(\n{tensors_str}\n)"
@@ -638,6 +639,7 @@ class LocalTensorMode(TorchDispatchMode):
def _unpatch_device_mesh(self) -> None:
assert self._old_get_coordinate is not None
DeviceMesh.get_coordinate = self._old_get_coordinate
# pyrefly: ignore # bad-assignment
self._old_get_coordinate = None

View File

@@ -316,6 +316,7 @@ def _local_all_gather_(
assert len(input_tensors) == 1
input_tensor = input_tensors[0]
# pyrefly: ignore # bad-assignment
output_tensors = output_tensors[0]
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
@@ -333,10 +334,12 @@ def _local_all_gather_(
# For each rank in the group, gather from their input tensor
for i, rank_i in enumerate(group_ranks):
# pyrefly: ignore # missing-attribute
output_tensors[i].copy_(input_tensor._local_tensors[rank_i])
work = FakeWork()
work_so = Work.boxed(work)
# pyrefly: ignore # bad-return
return ([output_tensors], work_so)
@@ -423,6 +426,7 @@ def _local_scatter_(
assert len(output_tensors) == 1
assert len(input_tensors) == 1
output_tensor = output_tensors[0]
# pyrefly: ignore # bad-assignment
input_tensors = input_tensors[0]
ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)

View File

@@ -89,6 +89,7 @@ class DTensorSpec:
if not isinstance(self.placements, tuple):
self.placements = tuple(self.placements)
if self.shard_order is None:
# pyrefly: ignore # bad-assignment
self.shard_order = DTensorSpec.compute_default_shard_order(self.placements)
self._hash: int | None = None

View File

@@ -700,6 +700,7 @@ def _restore_state_dict(
# Replace state dict attr names with the fqn
for name, _ in chain(
original_module.named_parameters(remove_duplicate=False),
# pyrefly: ignore # bad-argument-type
original_module.named_buffers(remove_duplicate=False),
):
if name in param_buffer_table_reverse:

View File

@@ -218,6 +218,7 @@ class FlexKernelOptions(TypedDict, total=False):
waves_per_eu: NotRequired[int]
"""ROCm-specific waves per execution unit."""
# pyrefly: ignore # invalid-annotation
force_flash: NotRequired[bool]
""" If True, forces use of the cute-dsl flash attention kernel.

View File

@@ -1,5 +1,5 @@
from . import parametrizations, parametrize, rnn, stateless
from .clip_grad import (
from .clip_grad import ( # pyrefly: ignore # deprecated
_clip_grads_with_norm_ as clip_grads_with_norm_,
_get_total_norm as get_total_norm,
clip_grad_norm,

View File

@@ -283,6 +283,7 @@ def clip_grad_value_(
clip_value = float(clip_value)
grads = [p.grad for p in parameters if p.grad is not None]
# pyrefly: ignore # bad-argument-type
grouped_grads = _group_tensors_by_device_and_dtype([grads])
for (device, _), ([grads], _) in grouped_grads.items():

View File

@@ -111,8 +111,10 @@ class _Orthogonal(Module):
Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
if hasattr(self, "base"):
# pyrefly: ignore # unbound-name
Q = self.base @ Q
if transposed:
# pyrefly: ignore # unbound-name
Q = Q.mT
return Q # type: ignore[possibly-undefined]

View File

@@ -297,6 +297,7 @@ class AveragedModel(Module):
avg_fn = get_swa_avg_fn()
n_averaged = self.n_averaged.to(device)
for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment]
# pyrefly: ignore # missing-attribute
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
else:
for p_averaged, p_model in zip( # type: ignore[assignment]

View File

@@ -71,6 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
nrows // 16, 16
)
).view(-1)
# pyrefly: ignore # unbound-name
outp = outp.index_copy(1, cols_permuted, outp)
# interleave_column_major_tensor

View File

@@ -67,6 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
# Because we cannot go from the compressed representation back to the dense representation currently,
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
# pyrefly: ignore # no-matching-overload
return self.__class__(
torch.Size([self.shape[-1], self.shape[0]]),
packed=self.packed_t,

View File

@@ -184,6 +184,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
outer_stride,
) -> torch.Tensor:
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
# pyrefly: ignore # no-matching-overload
return cls(
shape=shape,
packed=inner_tensors.get("packed", None),
@@ -413,6 +414,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
sparse_tensor_cutlass,
meta_tensor_cutlass,
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
# pyrefly: ignore # no-matching-overload
return cls(
original_tensor.shape,
packed=sparse_tensor_cutlass,
@@ -499,6 +501,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
original_tensor, algorithm=algorithm, use_cutlass=True
)
# pyrefly: ignore # no-matching-overload
return cls(
original_tensor.shape,
packed=packed,
@@ -560,6 +563,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
# pyrefly: ignore # no-matching-overload
return cls(
shape=original_tensor.shape,
packed=torch._cslt_compress(original_tensor),
@@ -626,6 +630,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
packed = packed.view(original_tensor.shape[0], -1)
packed_t = packed_t.view(original_tensor.shape[1], -1)
# pyrefly: ignore # no-matching-overload
return cls(
original_tensor.shape,
packed=packed,

View File

@@ -1330,6 +1330,7 @@ class Identity(sympy.Function):
def _sympystr(self, printer):
"""Controls how sympy's StrPrinter prints this"""
# pyrefly: ignore # missing-attribute
return f"({printer.doprint(self.args[0])})"
def _eval_is_real(self):