Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[RFC] Add pyrefly to lintrunner (#165179)
This adds pyrefly to lintrunner as a warning-only linter so we can collect feedback about the tool before switching to pyrefly as the main type checker. It follows the steps outlined in https://github.com/pytorch/pytorch/issues/163283.
Test plan: run `lintrunner init` and then `lintrunner`, and confirm that pyrefly errors are reported when present. Example results: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 7df9aca529
Commit: d795fb225a
.github/workflows/lint.yml (4 changed lines)
@@ -118,9 +118,9 @@ jobs:
           CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
           echo "Running all other linters"
           if [ "$CHANGED_FILES" = '*' ]; then
-            ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
+            ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
           else
-            ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh
+            ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
           fi

   quick-checks:
.lintrunner.toml
@@ -209,6 +209,46 @@ command = [
     '@{{PATHSFILE}}'
 ]

+
+[[linter]]
+code = 'PYREFLY'
+include_patterns = [
+    'torch/**/*.py',
+    'torch/**/*.pyi',
+    'torchgen/**/*.py',
+    'torchgen/**/*.pyi',
+    'functorch/**/*.py',
+    'functorch/**/*.pyi',
+]
+exclude_patterns = []
+command = [
+    'python3',
+    'tools/linter/adapters/pyrefly_linter.py',
+    '--config=pyrefly.toml',
+]
+init_command = [
+    'python3',
+    'tools/linter/adapters/pip_init.py',
+    '--dry-run={{DRYRUN}}',
+    'numpy==2.1.0 ; python_version >= "3.12"',
+    'expecttest==0.3.0',
+    'pyrefly==0.36.2',
+    'sympy==1.13.3',
+    'types-requests==2.27.25',
+    'types-pyyaml==6.0.2',
+    'types-tabulate==0.8.8',
+    'types-protobuf==5.29.1.20250403',
+    'types-setuptools==79.0.0.20250422',
+    'types-jinja2==2.11.9',
+    'types-colorama==0.4.6',
+    'filelock==3.18.0',
+    'junitparser==2.1.1',
+    'rich==14.1.0',
+    'optree==0.17.0',
+    'types-openpyxl==3.1.5.20250919',
+    'types-python-dateutil==2.9.0.20251008'
+]
+
 [[linter]]
 code = 'CLANGTIDY'
 include_patterns = [
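For context, lintrunner talks to adapters like the one configured above through a simple JSON-lines protocol: the `command` is run once (here with `--config=pyrefly.toml`), and the adapter prints one JSON object per finding to stdout. A minimal sketch of what one such finding looks like, using the same field names as the adapter added in this PR; the path, position, and message below are illustrative only:

import json

# One lintrunner finding, mirroring the LintMessage fields used by
# tools/linter/adapters/pyrefly_linter.py. Values are made up for illustration.
finding = {
    "path": "torch/example.py",   # hypothetical file
    "line": 12,
    "char": 4,
    "code": "PYREFLY",
    "severity": "advice",         # warning-only while mypy remains the gating checker
    "name": "bad-argument-type",
    "original": None,
    "replacement": None,
    "description": "Expected `int`, got `str`",
}
print(json.dumps(finding), flush=True)

Reporting every finding at "advice" severity is what makes this rollout warning-only: lintrunner surfaces the messages without treating them as hard failures.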
pyrefly.toml
@@ -1,5 +1,7 @@
 # A Pyrefly configuration for PyTorch
 # Based on https://github.com/pytorch/pytorch/blob/main/mypy.ini
+python-version = "3.12"
+
 project-includes = [
     "torch",
     "caffe2",
@@ -36,6 +38,7 @@ project-excludes = [
     "torch/nn/modules/rnn.py", # only remove when parsing errors are fixed
     "torch/_inductor/codecache.py",
     "torch/distributed/elastic/metrics/__init__.py",
+    "torch/_inductor/fx_passes/bucketing.py",
     # ====
     "benchmarks/instruction_counts/main.py",
     "benchmarks/instruction_counts/definitions/setup.py",
tools/linter/adapters/pyrefly_linter.py (new file, 258 lines)

from __future__ import annotations

import argparse
import json
import logging
import os
import re
import subprocess
import sys
import time
from enum import Enum
from typing import NamedTuple


class LintSeverity(str, Enum):
    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    path: str | None
    line: int | None
    char: int | None
    code: str
    severity: LintSeverity
    name: str
    original: str | None
    replacement: str | None
    description: str | None


# Note: This regex pattern is kept for reference but not used for pyrefly JSON parsing
RESULTS_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    (?:(?P<column>-?\d+):)?
    \s(?P<severity>\S+?):?
    \s(?P<message>.*)
    \s(?P<code>\[.*\])
    $
    """
)

# torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
    r"""(?mx)
    ^
    (?P<file>.*?):
    (?P<line>\d+):
    \s(?P<severity>\S+?):?
    \s(?P<message>INTERNAL\sERROR.*)
    $
    """
)


def run_command(
    args: list[str],
    *,
    extra_env: dict[str, str] | None,
    retries: int,
) -> subprocess.CompletedProcess[bytes]:
    logging.debug("$ %s", " ".join(args))
    start_time = time.monotonic()
    try:
        return subprocess.run(
            args,
            capture_output=True,
        )
    finally:
        end_time = time.monotonic()
        logging.debug("took %dms", (end_time - start_time) * 1000)


# Severity mapping (currently only used for stderr internal errors)
# Pyrefly JSON output doesn't include severity, so all errors default to ERROR
severities = {
    "error": LintSeverity.ERROR,
    "note": LintSeverity.ADVICE,
}


def check_pyrefly_installed(code: str) -> list[LintMessage]:
    cmd = ["pyrefly", "--version"]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
        return []
    except subprocess.CalledProcessError as e:
        msg = e.stderr.decode(errors="replace")
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=f"Could not run '{' '.join(cmd)}': {msg}",
            )
        ]


def in_github_actions() -> bool:
    return bool(os.getenv("GITHUB_ACTIONS"))


def check_files(
    code: str,
    config: str,
) -> list[LintMessage]:
    try:
        pyrefly_commands = [
            "pyrefly",
            "check",
            "--config",
            config,
            "--output-format=json",
        ]
        proc = run_command(
            [*pyrefly_commands],
            extra_env={},
            retries=0,
        )
    except OSError as err:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=(f"Failed due to {err.__class__.__name__}:\n{err}"),
            )
        ]
    stdout = str(proc.stdout, "utf-8").strip()
    stderr = str(proc.stderr, "utf-8").strip()
    if proc.returncode not in (0, 1):
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="command-failed",
                original=None,
                replacement=None,
                description=stderr,
            )
        ]

    # Parse JSON output from pyrefly
    try:
        if stdout:
            result = json.loads(stdout)
            errors = result.get("errors", [])
        else:
            errors = []
        # For now filter out deprecated warnings and only report type errors as warnings
        # until we remove mypy
        errors = [error for error in errors if error["name"] != "deprecated"]
        rc = [
            LintMessage(
                path=error["path"],
                name=error["name"],
                description=error.get(
                    "description", error.get("concise_description", "")
                ),
                line=error["line"],
                char=error["column"],
                code=code,
                severity=LintSeverity.ADVICE,
                # uncomment and replace when we switch to pyrefly
                # severity=LintSeverity.ADVICE if error["name"] == "deprecated" else LintSeverity.ERROR,
                original=None,
                replacement=None,
            )
            for error in errors
        ]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        return [
            LintMessage(
                path=None,
                line=None,
                char=None,
                code=code,
                severity=LintSeverity.ERROR,
                name="json-parse-error",
                original=None,
                replacement=None,
                description=f"Failed to parse pyrefly JSON output: {e}",
            )
        ]

    # Still check stderr for internal errors
    rc += [
        LintMessage(
            path=match["file"],
            name="INTERNAL ERROR",
            description=match["message"],
            line=int(match["line"]),
            char=None,
            code=code,
            severity=severities.get(match["severity"], LintSeverity.ERROR),
            original=None,
            replacement=None,
        )
        for match in INTERNAL_ERROR_RE.finditer(stderr)
    ]
    return rc


def main() -> None:
    parser = argparse.ArgumentParser(
        description="pyrefly wrapper linter.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "--code",
        default="PYREFLY",
        help="the code this lint should report as",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    parser.add_argument(
        "--config",
        required=True,
        help="path to an mypy .ini config file",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=logging.INFO,
        stream=sys.stderr,
    )

    lint_messages = check_pyrefly_installed(args.code) + check_files(
        args.code, args.config
    )
    for lint_message in lint_messages:
        print(json.dumps(lint_message._asdict()), flush=True)


if __name__ == "__main__":
    main()
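The adapter above assumes `pyrefly check --output-format=json` writes a single JSON document whose `errors` list carries `path`, `line`, `column`, `name`, and a `description` (or `concise_description`) per finding; that is the shape `check_files()` indexes into. A small self-contained sketch of that parsing and the same "deprecated" filtering, with a made-up payload standing in for real pyrefly output:

import json

# Made-up stand-in for pyrefly's JSON output; the field names match what
# check_files() above reads, the values are illustrative.
payload = json.dumps({
    "errors": [
        {"path": "torch/foo.py", "line": 3, "column": 1,
         "name": "bad-assignment", "description": "`None` is not assignable to `int`"},
        {"path": "torch/bar.py", "line": 7, "column": 5,
         "name": "deprecated", "concise_description": "call to a deprecated function"},
    ]
})

errors = json.loads(payload).get("errors", [])
# Same filtering the adapter applies: drop "deprecated" findings while mypy
# is still the authoritative checker, and report everything else.
errors = [e for e in errors if e["name"] != "deprecated"]
for e in errors:
    print(e["path"], e["line"], e.get("description", e.get("concise_description", "")))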
@@ -76,6 +76,7 @@ class StreamVariable(VariableTracker):
         super().__init__(**kwargs)
         self.proxy = proxy
         self.value = value
+        # pyrefly: ignore # read-only
         self.device = device

     def python_type(self) -> type:
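Most of the remaining hunks in this commit follow a single pattern, like the one above: a `# pyrefly: ignore # <error-kind>` comment is placed on the line above code that pyrefly currently flags, suppressing that one finding without changing runtime behavior (a few hunks add mypy-style `# type: ignore[...]` comments instead, where mypy is the checker being silenced). A minimal, hypothetical illustration of the pyrefly suppression pattern as it is used in this diff:

def takes_int(x: int) -> int:
    return x + 1


def caller(value: str) -> int:
    # Without the comment below, pyrefly would flag this call site; the
    # suppression only silences the type checker, the code is unchanged.
    # pyrefly: ignore # bad-argument-type
    return takes_int(value)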
@@ -1492,6 +1492,7 @@ def _aot_stage2a_partition(

     # apply joint_gm callback here
     if callable(torch._functorch.config.joint_custom_pass):
+        # pyrefly: ignore # bad-assignment
         fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)

     static_lifetime_input_indices = fw_metadata.static_input_indices

@@ -1761,6 +1762,7 @@ def _aot_stage2b_bw_compile(
             # tensor which is wrong.

             ph_size = ph_arg.size()
+            # pyrefly: ignore # bad-argument-type
             if len(ph_size) == 0 and len(real_stride) > 0:
                 # Fix for 0-dimensional tensors: When a tensor becomes 0-d
                 # (e.g., via squeeze), its stride should be () not (1,).

@@ -628,6 +628,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
         position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
     ]
     # add the scale nodes to the output find the first sym_node in the output
+    # pyrefly: ignore # bad-argument-type
     idx = find_first_sym_node(output_updated_args)
     scale_nodes = tensor_scale_nodes + sym_scale_nodes
     if scale_nodes:
@@ -86,7 +86,7 @@ def bucket_all_gather(
     mode: BucketMode = "default",
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
-        from torch._inductor.fx_passes.bucketing import (
+        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore # missing-module-attribute
             bucket_cap_mb_by_bucket_idx_default,
         )

@@ -103,7 +103,7 @@ def bucket_reduce_scatter(
     mode: BucketMode = "default",
 ) -> None:
     if bucket_cap_mb_by_bucket_idx is None:
-        from torch._inductor.fx_passes.bucketing import (
+        from torch._inductor.fx_passes.bucketing import (  # pyrefly: ignore # missing-module-attribute
             bucket_cap_mb_by_bucket_idx_default,
         )
@@ -209,6 +209,7 @@ def addmm_patterns_init():
         # pyrefly: ignore # bad-argument-type
         int8_woq_fusion_replacement,
         [val(), val(), val(), val(), scale(), scale(), scale()],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_patterns[0],
@@ -230,6 +231,7 @@ def addmm_patterns_init():
         # pyrefly: ignore # bad-argument-type
         matmul_replacement,
         [val(), val(), val(), val()],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_patterns[0],
@@ -251,6 +253,7 @@ def addmm_patterns_init():
         # pyrefly: ignore # bad-argument-type
         matmul_replacement_two,
         [val(), val(), val()],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_patterns[0],
@@ -276,6 +279,7 @@ def addmm_patterns_init():
         # pyrefly: ignore # bad-argument-type
         addmm_fuse_replacement_second,
         [val() for _ in range(7)],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_patterns[0],
@@ -49,6 +49,7 @@ def _misc_patterns_init():
         # pyrefly: ignore # bad-argument-type
         randperm_index_add_replacement,
         [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         [post_grad_patterns, joint_graph_patterns],
@@ -68,6 +69,7 @@ def _misc_patterns_init():
         # pyrefly: ignore # bad-argument-type
         randperm_index_replacement,
         [torch.empty(4, 8, device=device)],
+        # pyrefly: ignore # bad-argument-type
         fwd_only,
         # pyrefly: ignore # bad-argument-type
         [post_grad_patterns, joint_graph_patterns],
@@ -919,6 +919,7 @@ def _pad_mm_init() -> None:
             pattern,
             replacement,
             args,
+            # pyrefly: ignore # bad-argument-type
             joint_fwd_bwd,
             # pyrefly: ignore # bad-argument-type
             patterns,
@@ -931,6 +932,7 @@ def _pad_mm_init() -> None:
             pattern,
             replacement,
             args,
+            # pyrefly: ignore # bad-argument-type
             fwd_only,
             # pyrefly: ignore # bad-argument-type
             patterns,
@@ -666,6 +666,7 @@ def lazy_init():
         prepare_softmax_replacement,
         [torch.empty(4, 8)],
         scalar_workaround=dict(dim=-1),
+        # pyrefly: ignore # bad-argument-type
         trace_fn=fwd_only,
         # pyrefly: ignore # bad-argument-type
         pass_dicts=pass_patterns[1],
@@ -730,6 +731,7 @@ def register_lowering_pattern(
     return pattern_matcher.register_lowering_pattern(
         pattern,
         extra_check,
+        # pyrefly: ignore # bad-argument-type
         pass_dict=pass_patterns[pass_number],
     )

@@ -1573,6 +1575,7 @@ def register_partial_reduction_pattern():

     @register_graph_pattern(
         MultiOutputPattern([partial_reduc, full_reduc]),
+        # pyrefly: ignore # bad-argument-type
         pass_dict=pass_patterns[2],
     )
     def reuse_partial(match, input, reduced_dims, keepdim):
@@ -27,7 +27,7 @@ from torch._dynamo.utils import counters
 from torch._higher_order_ops.associative_scan import associative_scan_op
 from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
 from torch._library.utils import get_layout_constraint_tag
-from torch._prims_common import (
+from torch._prims_common import (  # pyrefly: ignore # deprecated
     canonicalize_dim,
     canonicalize_dims,
     check,
@@ -173,6 +173,7 @@ def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
             f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
         )
     _OPAQUE_TYPES[cls] = name
+    # pyrefly: ignore # missing-attribute
     torch._C._register_opaque_type(name)


@@ -182,4 +183,5 @@ def is_opaque_type(cls: Any) -> bool:
     """
     if cls not in _OPAQUE_TYPES:
         return False
+    # pyrefly: ignore # missing-attribute
     return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])
@@ -135,7 +135,7 @@ if is_available():
     # this.
     # pyrefly: ignore # deprecated
     from .distributed_c10d import *  # noqa: F403
-    from .distributed_c10d import (
+    from .distributed_c10d import (  # pyrefly: ignore # deprecated
         _all_gather_base,
         _coalescing_manager,
         _CoalescingManager,
@@ -1009,8 +1009,8 @@ lib_impl.impl("broadcast", _broadcast_meta, "Meta")
 lib_impl.impl("broadcast_", _broadcast__meta, "Meta")

 # mark these ops has side effect so that they won't be removed by DCE
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
-torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)
+torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)  # type: ignore[has-type]
+torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)  # type: ignore[has-type]

 # Register legacy ops for backward compatibility
 # TODO(yifu): remove these in functional collective beta release
@@ -1176,7 +1176,7 @@ def all_gather_inplace(
     return tensor_list


-from torch.distributed.distributed_c10d import (
+from torch.distributed.distributed_c10d import (  # pyrefly: ignore # deprecated
     _all_gather_base as legacy_all_gather_base,
     _reduce_scatter_base as legacy_reduce_scatter_base,
     all_gather as legacy_all_gather,
@@ -1190,11 +1190,11 @@ from torch.distributed.distributed_c10d import (
 # This dict should contain sets of functions that dynamo is allowed to remap.
 # Functions in this set should accept the same args/kwargs 1:1 as their mapping.
 traceable_collective_remaps = {
-    legacy_allgather: all_gather_tensor_inplace,
-    legacy_reducescatter: reduce_scatter_tensor_inplace,
-    legacy_allreduce: all_reduce_inplace,
-    legacy_all_to_all_single: all_to_all_inplace,
-    legacy_all_gather: all_gather_inplace,
-    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
-    legacy_all_gather_base: all_gather_tensor_inplace,
+    legacy_allgather: all_gather_tensor_inplace,  # type: ignore[has-type]
+    legacy_reducescatter: reduce_scatter_tensor_inplace,  # type: ignore[has-type]
+    legacy_allreduce: all_reduce_inplace,  # type: ignore[has-type]
+    legacy_all_to_all_single: all_to_all_inplace,  # type: ignore[has-type]
+    legacy_all_gather: all_gather_inplace,  # type: ignore[has-type]
+    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,  # type: ignore[has-type]
+    legacy_all_gather_base: all_gather_tensor_inplace,  # type: ignore[has-type]
 }
@@ -393,6 +393,7 @@ class LocalTensor(torch.Tensor):
     def __repr__(self) -> str:  # type: ignore[override]
         parts = []
         for k, v in self._local_tensors.items():
+            # pyrefly: ignore # bad-argument-type
             parts.append(f" {k}: {v}")
         tensors_str = ",\n".join(parts)
         return f"LocalTensor(\n{tensors_str}\n)"
@@ -680,6 +681,7 @@ class LocalTensorMode(TorchDispatchMode):
     def _unpatch_device_mesh(self) -> None:
         assert self._old_get_coordinate is not None
         DeviceMesh.get_coordinate = self._old_get_coordinate
+        # pyrefly: ignore # bad-assignment
         self._old_get_coordinate = None


@@ -316,6 +316,7 @@ def _local_all_gather_(
     assert len(input_tensors) == 1

     input_tensor = input_tensors[0]
+    # pyrefly: ignore # bad-assignment
     output_tensors = output_tensors[0]

     ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
@@ -336,10 +337,12 @@ def _local_all_gather_(
             source_tensor = input_tensor
             if isinstance(input_tensor, LocalTensor):
                 source_tensor = input_tensor._local_tensors[rank_i]
+            # pyrefly: ignore # missing-attribute
             output_tensors[i].copy_(source_tensor)

     work = FakeWork()
     work_so = Work.boxed(work)
+    # pyrefly: ignore # bad-return
     return ([output_tensors], work_so)


@@ -426,6 +429,7 @@ def _local_scatter_(
     assert len(output_tensors) == 1
     assert len(input_tensors) == 1
     output_tensor = output_tensors[0]
+    # pyrefly: ignore # bad-assignment
     input_tensors = input_tensors[0]

     ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)
@@ -90,6 +90,7 @@ class DTensorSpec:
         if not isinstance(self.placements, tuple):
             self.placements = tuple(self.placements)
         if self.shard_order is None:
+            # pyrefly: ignore # bad-assignment
             self.shard_order = DTensorSpec.compute_default_shard_order(self.placements)
         self._hash: int | None = None

@@ -701,6 +701,7 @@ def _restore_state_dict(
     for name, _ in list(
         chain(
             original_module.named_parameters(remove_duplicate=False),
+            # pyrefly: ignore # bad-argument-type
             original_module.named_buffers(remove_duplicate=False),
         )
     ):
@@ -218,6 +218,7 @@ class FlexKernelOptions(TypedDict, total=False):
     waves_per_eu: NotRequired[int]
     """ROCm-specific waves per execution unit."""

+    # pyrefly: ignore # invalid-annotation
     force_flash: NotRequired[bool]
     """ If True, forces use of the cute-dsl flash attention kernel.
@@ -1,5 +1,5 @@
 from . import parametrizations, parametrize, rnn, stateless
-from .clip_grad import (
+from .clip_grad import (  # pyrefly: ignore # deprecated
     _clip_grads_with_norm_ as clip_grads_with_norm_,
     _get_total_norm as get_total_norm,
     clip_grad_norm,
@@ -283,6 +283,7 @@ def clip_grad_value_(
     clip_value = float(clip_value)

     grads = [p.grad for p in parameters if p.grad is not None]
+    # pyrefly: ignore # bad-argument-type
     grouped_grads = _group_tensors_by_device_and_dtype([grads])

     for (device, _), ([grads], _) in grouped_grads.items():
@@ -111,8 +111,10 @@ class _Orthogonal(Module):
             Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

         if hasattr(self, "base"):
+            # pyrefly: ignore # unbound-name
             Q = self.base @ Q
         if transposed:
+            # pyrefly: ignore # unbound-name
             Q = Q.mT
         return Q  # type: ignore[possibly-undefined]
@@ -170,6 +170,7 @@ class TorchTensor(ir.Tensor):

         if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
             raise TypeError(
+                # pyrefly: ignore # missing-attribute
                 f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
                 "with a tensor backed by real data using ONNXProgram.apply_weights() "
                 "or save the model without initializers by setting include_initializers=False."
@@ -297,6 +297,7 @@ class AveragedModel(Module):
             avg_fn = get_swa_avg_fn()
             n_averaged = self.n_averaged.to(device)
             for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
+                # pyrefly: ignore # missing-attribute
                 p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
         else:
             for p_averaged, p_model in zip(  # type: ignore[assignment]
@@ -71,6 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                 nrows // 16, 16
             )
         ).view(-1)
+        # pyrefly: ignore # unbound-name
         outp = outp.index_copy(1, cols_permuted, outp)

     # interleave_column_major_tensor
@@ -67,6 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
     # Because we cannot go from the compressed representation back to the dense representation currently,
     # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
     # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
+    # pyrefly: ignore # no-matching-overload
     return self.__class__(
         torch.Size([self.shape[-1], self.shape[0]]),
         packed=self.packed_t,
@@ -184,6 +184,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         outer_stride,
     ) -> torch.Tensor:
         shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
+        # pyrefly: ignore # no-matching-overload
         return cls(
             shape=shape,
             packed=inner_tensors.get("packed", None),
@@ -413,6 +414,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
             sparse_tensor_cutlass,
             meta_tensor_cutlass,
         ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
+        # pyrefly: ignore # no-matching-overload
         return cls(
             original_tensor.shape,
             packed=sparse_tensor_cutlass,
@@ -499,6 +501,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
             original_tensor, algorithm=algorithm, use_cutlass=True
         )

+        # pyrefly: ignore # no-matching-overload
         return cls(
             original_tensor.shape,
             packed=packed,
@@ -560,6 +563,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         cls, original_tensor: torch.Tensor
     ) -> "SparseSemiStructuredTensorCUSPARSELT":
         cls._validate_device_dim_dtype_shape(original_tensor)
+        # pyrefly: ignore # no-matching-overload
         return cls(
             shape=original_tensor.shape,
             packed=torch._cslt_compress(original_tensor),
@@ -626,6 +630,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         packed = packed.view(original_tensor.shape[0], -1)
         packed_t = packed_t.view(original_tensor.shape[1], -1)

+        # pyrefly: ignore # no-matching-overload
         return cls(
             original_tensor.shape,
             packed=packed,
|
@ -1336,6 +1336,7 @@ class Identity(sympy.Function):
|
|||||||
|
|
||||||
def _sympystr(self, printer):
|
def _sympystr(self, printer):
|
||||||
"""Controls how sympy's StrPrinter prints this"""
|
"""Controls how sympy's StrPrinter prints this"""
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return f"({printer.doprint(self.args[0])})"
|
return f"({printer.doprint(self.args[0])})"
|
||||||
|
|
||||||
def _eval_is_real(self):
|
def _eval_is_real(self):
|
||||||