mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Expand type checking to mypy strict files (#165697)
Expands Pyrefly type checking to check the files outlined in the mypy-strict.ini configuration file: Pull Request resolved: https://github.com/pytorch/pytorch/pull/165697 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
9095a9dfae
commit
f02e3947f6
16
pyrefly.toml
16
pyrefly.toml
@ -5,6 +5,7 @@ python-version = "3.12"
|
||||
project-includes = [
|
||||
"torch",
|
||||
"caffe2",
|
||||
"tools",
|
||||
"test/test_bundled_images.py",
|
||||
"test/test_bundled_inputs.py",
|
||||
"test/test_complex.py",
|
||||
@ -24,8 +25,11 @@ project-excludes = [
|
||||
# ==== to test Pyrefly on a specific directory, simply comment it out ====
|
||||
"torch/_inductor/runtime",
|
||||
"torch/_inductor/codegen/triton.py",
|
||||
"tools/linter/adapters/test_device_bias_linter.py",
|
||||
"tools/code_analyzer/gen_operators_yaml.py",
|
||||
# formatting issues, will turn on after adjusting where suppressions can be
|
||||
# in import statements
|
||||
"tools/flight_recorder/components/types.py",
|
||||
"torch/linalg/__init__.py",
|
||||
"torch/package/importer.py",
|
||||
"torch/package/_package_pickler.py",
|
||||
@ -40,17 +44,6 @@ project-excludes = [
|
||||
"torch/distributed/elastic/metrics/__init__.py",
|
||||
"torch/_inductor/fx_passes/bucketing.py",
|
||||
# ====
|
||||
"benchmarks/instruction_counts/main.py",
|
||||
"benchmarks/instruction_counts/definitions/setup.py",
|
||||
"benchmarks/instruction_counts/applications/ci.py",
|
||||
"benchmarks/instruction_counts/core/api.py",
|
||||
"benchmarks/instruction_counts/core/expand.py",
|
||||
"benchmarks/instruction_counts/core/types.py",
|
||||
"benchmarks/instruction_counts/core/utils.py",
|
||||
"benchmarks/instruction_counts/definitions/standard.py",
|
||||
"benchmarks/instruction_counts/definitions/setup.py",
|
||||
"benchmarks/instruction_counts/execution/runner.py",
|
||||
"benchmarks/instruction_counts/execution/work.py",
|
||||
"torch/include/**",
|
||||
"torch/csrc/**",
|
||||
"torch/distributed/elastic/agent/server/api.py",
|
||||
@ -137,3 +130,4 @@ errors.bad-param-name-override = false
|
||||
errors.implicit-import = false
|
||||
permissive-ignores = true
|
||||
replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
|
||||
search-path = ["tools/experimental"]
|
||||
|
@ -863,6 +863,7 @@ static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
||||
saved_variables.append(f"{type.cpp_type()} {name};")
|
||||
|
||||
if type in MISC_GETTER_DEFS:
|
||||
# pyrefly: ignore # index-error
|
||||
getter_def, body = MISC_GETTER_DEFS[type]
|
||||
getter_definitions.append(
|
||||
getter_def.substitute(op=info.op, name=name, body=body)
|
||||
@ -1033,6 +1034,7 @@ static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
||||
unpack_ivalues = []
|
||||
for typ, name in zip(apply_functional_args_ref_types, apply_functional_args):
|
||||
typ = typ.removesuffix("&")
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();")
|
||||
|
||||
schema_args = [f"std::array<bool, {len(input_name_to_idx)}>"]
|
||||
|
@ -182,6 +182,7 @@ def format_trace_inputs(f: NativeFunction) -> str:
|
||||
ADD_TRACE_INPUT.substitute(
|
||||
name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name
|
||||
)
|
||||
# pyrefly: ignore # unbound-name
|
||||
for i in range(num_out_args)
|
||||
]
|
||||
|
||||
|
@ -1495,6 +1495,7 @@ def emit_body(
|
||||
else:
|
||||
expr = f"SavedVariable({var}, {str(is_output).lower()})"
|
||||
if foreacharg is not None and "original_selfs" not in expr:
|
||||
# pyrefly: ignore # unbound-name
|
||||
expr = expr.replace(src_name, name_in_expr)
|
||||
elif (
|
||||
type == BaseCType(tensorListT)
|
||||
@ -1844,12 +1845,14 @@ def emit_body(
|
||||
)
|
||||
)
|
||||
cur_derivative_conditions.append(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
|
||||
req_inp=inp_name + "[i]"
|
||||
)
|
||||
)
|
||||
else:
|
||||
cur_derivative_conditions.append(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp_name)
|
||||
)
|
||||
|
||||
@ -1920,6 +1923,7 @@ def emit_body(
|
||||
unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
|
||||
inp_name="original_self",
|
||||
inp="original_self" + input_suffix,
|
||||
# pyrefly: ignore # unbound-name
|
||||
zeros_fn=zeros_fn,
|
||||
)
|
||||
unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
|
||||
|
@ -95,8 +95,11 @@ def add_view_copy_derivatives(
|
||||
else:
|
||||
break
|
||||
# prefer manually-defined derivatives if any
|
||||
# pyrefly: ignore # unbound-name
|
||||
if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
|
||||
# pyrefly: ignore # unbound-name
|
||||
assert fn_schema is not None
|
||||
# pyrefly: ignore # unbound-name
|
||||
view_infos[fn_schema] = view_copy_differentiability_infos
|
||||
|
||||
infos.update(view_infos)
|
||||
@ -398,6 +401,7 @@ def postprocess_forward_derivatives(
|
||||
for arg_name in all_arg_names:
|
||||
if arg_name in diff_arg_names:
|
||||
arg_name = arg_name + "_t"
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
new_args.append(arg_name)
|
||||
|
||||
# TODO we are trolling
|
||||
@ -938,6 +942,7 @@ def saved_variables(
|
||||
+ f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}"
|
||||
)
|
||||
for nctype in nctypes:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
name = (
|
||||
nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
|
||||
)
|
||||
@ -947,6 +952,7 @@ def saved_variables(
|
||||
|
||||
def repl(m: re.Match[str]) -> str:
|
||||
suffix: str = (
|
||||
# pyrefly: ignore # bad-assignment
|
||||
info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
|
||||
)
|
||||
expr: str = info["expr"](name) if "expr" in info else m.group(0)
|
||||
|
@ -67,6 +67,7 @@ def is_intrested_file(
|
||||
|
||||
# ignore files that are not belong to pytorch
|
||||
if platform == TestPlatform.OSS:
|
||||
# pyrefly: ignore # import-error
|
||||
from package.oss.utils import get_pytorch_folder
|
||||
|
||||
if not file_path.startswith(get_pytorch_folder()):
|
||||
|
@ -24,6 +24,7 @@ def report_download_progress(
|
||||
file_size: int,
|
||||
) -> None:
|
||||
if file_size != -1:
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
percent = min(1, (chunk_number * chunk_size) / file_size)
|
||||
bar = "#" * int(64 * percent)
|
||||
sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
|
||||
|
@ -105,8 +105,10 @@ def extract_info_from_keyword(source: str, kw: ast.keyword) -> Any:
|
||||
evaluated_context = []
|
||||
for value in kw.value.values:
|
||||
if isinstance(value, ast.FormattedValue):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
evaluated_context.append(f"{{{ast.unparse(value.value)}}}")
|
||||
elif isinstance(value, ast.Constant):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
evaluated_context.append(value.value)
|
||||
return "".join(evaluated_context)
|
||||
else:
|
||||
@ -152,6 +154,7 @@ def find_unimplemented_v2_calls(
|
||||
|
||||
for kw in node.keywords:
|
||||
if kw.arg in info:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
info[kw.arg] = extract_info_from_keyword(source, kw)
|
||||
|
||||
if info["gb_type"] is None:
|
||||
|
@ -296,6 +296,7 @@ def run_multi_process_fuzzer(
|
||||
)
|
||||
|
||||
def write_func(msg):
|
||||
# pyrefly: ignore # missing-attribute
|
||||
pbar.write(msg)
|
||||
else:
|
||||
persist_print("Progress: (install tqdm for better progress bar)")
|
||||
|
@ -111,6 +111,7 @@ class ConstantOperator(Operator):
|
||||
]:
|
||||
# Clamp integer values to [0, 3] to avoid index overflow in multiplication
|
||||
# Even with multiplication, indices should stay in reasonable range
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
fill_value = max(0, min(3, abs(fill_value)))
|
||||
|
||||
tensor_creation = (
|
||||
|
@ -78,15 +78,22 @@ class JobConfig:
|
||||
def parse_args(
|
||||
self: "JobConfig", args: Optional[Sequence[str]]
|
||||
) -> argparse.Namespace:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
args = self.parser.parse_args(args)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
if args.selected_ranks is not None:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
assert args.just_print_entries, (
|
||||
"Not support selecting ranks without printing entries"
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
if args.pg_filters is not None:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
assert args.just_print_entries, (
|
||||
"Not support selecting pg filters without printing entries"
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
if args.verbose:
|
||||
logger.set_log_level(logging.DEBUG)
|
||||
# pyrefly: ignore # bad-return
|
||||
return args
|
||||
|
@ -41,6 +41,7 @@ def format_frame(frame: dict[str, str]) -> str:
|
||||
def format_frames(frames: list[dict[str, str]]) -> str:
|
||||
formatted_frames = []
|
||||
for frame in frames:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
formatted_frames.append(format_frame(frame))
|
||||
return "\n".join(formatted_frames)
|
||||
|
||||
@ -695,6 +696,7 @@ def check_version(version_by_ranks: dict[str, str], version: str) -> None:
|
||||
|
||||
|
||||
def get_version_detail(version: str) -> tuple[int, int]:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
version = version.split(".")
|
||||
assert len(version) == 2, f"Invalid version {version}"
|
||||
major, minor = map(int, version)
|
||||
|
@ -40,11 +40,17 @@ from tools.flight_recorder.components.types import types
|
||||
|
||||
def main(args: Optional[Sequence[str]] = None) -> None:
|
||||
config = JobConfig()
|
||||
# pyrefly: ignore # bad-assignment
|
||||
args = config.parse_args(args)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
assert args.trace_dir, "Trace directory trace_dir is required"
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
details, version = read_dir(args)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
db = build_db(details, args, version)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
if args.output:
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
with open(args.output, "wb") as f:
|
||||
pickle.dump((types, db), f)
|
||||
|
||||
|
@ -34,6 +34,7 @@ class TensorRepr(gdb.Command): # type: ignore[misc, no-any-unimported]
|
||||
on it.
|
||||
"""
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
__doc__ = textwrap.dedent(__doc__).strip()
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
@ -118,6 +118,7 @@ def extract_filename(path: str, keep_ext: bool = True) -> Any:
|
||||
|
||||
|
||||
# https://gist.github.com/pypt/94d747fe5180851196eb
|
||||
# pyrefly: ignore # invalid-inheritance
|
||||
class UniqueKeyLoader(Loader):
|
||||
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
|
||||
if not isinstance(node, MappingNode):
|
||||
@ -233,6 +234,7 @@ def preprocess(
|
||||
last_indent = input_indent
|
||||
|
||||
while blank_lines != 0:
|
||||
# pyrefly: ignore # unbound-name
|
||||
python_lines.append(python_indent + "print(file=OUT_STREAM)")
|
||||
blank_lines -= 1
|
||||
|
||||
@ -667,6 +669,7 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
|
||||
" ",
|
||||
)
|
||||
|
||||
# pyrefly: ignore # unbound-name
|
||||
return shader_dispatch_str
|
||||
|
||||
|
||||
@ -681,15 +684,18 @@ def genCppFiles(
|
||||
name = getName(spvPath).replace("_spv", "")
|
||||
|
||||
sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
spv_bin_strs.append(spv_bin_str)
|
||||
|
||||
shader_info = getShaderInfo(srcPath)
|
||||
|
||||
register_shader_info_strs.append(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
generateShaderInfoStr(shader_info, name, sizeBytes)
|
||||
)
|
||||
|
||||
if shader_info.register_for is not None:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))
|
||||
|
||||
spv_bin_arrays = "\n".join(spv_bin_strs)
|
||||
|
@ -131,12 +131,14 @@ class ComputeCodegenUnboxedKernels:
|
||||
else:
|
||||
arg_cpp = f"c10::IValue({arg_default})"
|
||||
args_code.append(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
f"""c10::Argument("{arg.name}", nullptr, ::std::nullopt, {arg_cpp})"""
|
||||
)
|
||||
|
||||
returns = f.func.returns
|
||||
returns_code = []
|
||||
for ret in returns:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
returns_code.append(f"""c10::Argument("{ret.name if ret.name else ""}")""")
|
||||
return f"""
|
||||
// aten::{schema}
|
||||
|
@ -112,6 +112,7 @@ class FileLinter:
|
||||
first_results = None
|
||||
original = replacement = pf.contents
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
while True:
|
||||
try:
|
||||
results = sorted(self._lint(pf), key=LintResult.sort_key)
|
||||
|
@ -41,6 +41,7 @@ class LineWithSets:
|
||||
t = self.tokens[i]
|
||||
after = i < len(self.tokens) - 1 and self.tokens[i + 1]
|
||||
if t.string == "Set" and t.type == token.NAME:
|
||||
# pyrefly: ignore # bad-return
|
||||
return after and after.string == "[" and after.type == token.OP
|
||||
return (
|
||||
(t.string == "set" and t.type == token.NAME)
|
||||
|
@ -19,11 +19,13 @@ from typing import NamedTuple
|
||||
# PyTorch directory root
|
||||
def scm_root() -> str:
|
||||
path = os.path.abspath(os.getcwd())
|
||||
# pyrefly: ignore # bad-assignment
|
||||
while True:
|
||||
if os.path.exists(os.path.join(path, ".git")):
|
||||
return path
|
||||
if os.path.isdir(os.path.join(path, ".hg")):
|
||||
return path
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
n = len(path)
|
||||
path = os.path.dirname(path)
|
||||
if len(path) == n:
|
||||
|
@ -101,6 +101,7 @@ def check_dictionary(filename: str) -> list[LintMessage]:
|
||||
words_set = set(words)
|
||||
if len(words) != len(words_set):
|
||||
raise ValueError("The dictionary file contains duplicate entries.")
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
uncased_words = list(map(str.lower, words))
|
||||
if uncased_words != sorted(uncased_words):
|
||||
raise ValueError(
|
||||
|
@ -12,6 +12,7 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple
|
||||
|
||||
# pyrefly: ignore # import-error
|
||||
import isort
|
||||
import usort
|
||||
|
||||
|
@ -55,6 +55,7 @@ def report_download_progress(
|
||||
Pretty printer for file download progress.
|
||||
"""
|
||||
if file_size != -1:
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
percent = min(1, (chunk_number * chunk_size) / file_size)
|
||||
bar = "#" * int(64 * percent)
|
||||
sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
|
||||
|
@ -15,7 +15,10 @@ import multiprocessing as mp
|
||||
from enum import Enum
|
||||
from typing import NamedTuple
|
||||
|
||||
# pyrefly: ignore # import-error
|
||||
import libcst as cst
|
||||
|
||||
# pyrefly: ignore # import-error
|
||||
import libcst.matchers as m
|
||||
|
||||
|
||||
|
@ -69,6 +69,7 @@ def print_lint_message(path: Path, job: dict[str, Any], sync_tag: str) -> None:
|
||||
|
||||
lint_message = LintMessage(
|
||||
path=str(path),
|
||||
# pyrefly: ignore # unbound-name
|
||||
line=line_number,
|
||||
char=None,
|
||||
code="WORKFLOWSYNC",
|
||||
|
@ -73,6 +73,7 @@ def get_selected_kernel_dtypes_code(
|
||||
for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
|
||||
conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
|
||||
body_parts.append(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
if_condition_template.substitute(
|
||||
kernel_tag_name=kernel_tag,
|
||||
dtype_checks=" || ".join(conditions),
|
||||
|
@ -311,6 +311,7 @@ class Venv:
|
||||
python=python,
|
||||
capture_output=True,
|
||||
).stdout
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
candidates = list(map(Path, filter(None, map(str.strip, output.splitlines()))))
|
||||
candidates = [p for p in candidates if p.is_dir() and p.name == "site-packages"]
|
||||
if not candidates:
|
||||
@ -480,6 +481,7 @@ class Venv:
|
||||
cmd = [str(python), *args]
|
||||
env = popen_kwargs.pop("env", None) or {}
|
||||
check = popen_kwargs.pop("check", True)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
check=check,
|
||||
@ -531,6 +533,7 @@ class Venv:
|
||||
cmd = [str(self.bindir / "uv"), *args]
|
||||
env = popen_kwargs.pop("env", None) or {}
|
||||
check = popen_kwargs.pop("check", True)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
check=check,
|
||||
@ -938,6 +941,7 @@ def _move_single(
|
||||
|
||||
def _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
|
||||
for src in listing:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
_move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
|
||||
|
||||
|
||||
|
@ -118,6 +118,7 @@ def download_patch(pr_number: int, repo_url: str, download_dir: str) -> str:
|
||||
urllib.request.urlopen(patch_url) as response,
|
||||
open(patch_file, "wb") as out_file,
|
||||
):
|
||||
# pyrefly: ignore # bad-specialization
|
||||
shutil.copyfileobj(response, out_file)
|
||||
if not os.path.isfile(patch_file):
|
||||
print(f"Failed to download patch for PR #{pr_number}")
|
||||
|
@ -994,6 +994,7 @@ def add_docstr_to_hint(docstr: str, hint: str) -> str:
|
||||
hint = hint.removesuffix("...").rstrip() # remove "..."
|
||||
content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix=" ")
|
||||
# Remove trailing whitespace on each line
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
|
||||
|
||||
# attribute or property
|
||||
|
@ -100,6 +100,7 @@ class CMake:
|
||||
if ver is not None:
|
||||
eprint(f"Found {cmd} ({command}) version: {ver}", end="")
|
||||
cmake_versions.append(f"{cmd}=={ver}")
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
if ver >= CMAKE_MINIMUM_VERSION:
|
||||
eprint(f" (>={CMAKE_MINIMUM_VERSION})")
|
||||
valid_cmake_versions[cmd] = ver
|
||||
|
@ -31,7 +31,9 @@ def gen_linker_script(
|
||||
text_line_start = text_line_start[0]
|
||||
|
||||
# ensure that parent directory exists before writing
|
||||
# pyrefly: ignore # bad-assignment
|
||||
fout = Path(fout)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
fout.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(fout, "w") as f:
|
||||
|
@ -60,6 +60,7 @@ class SegmentGenerator:
|
||||
df[time_col_name] = pd.to_datetime(df[time_col_name], unit="s", utc=True)
|
||||
|
||||
# get unique cmd names
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
unique_cmds_df = pd.DataFrame(df[cmd_col_name].unique(), columns=[cmd_col_name])
|
||||
|
||||
# get all detected python cmds
|
||||
|
@ -7,6 +7,7 @@ import unittest
|
||||
from collections import defaultdict
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
# pyrefly: ignore # import-error
|
||||
from gen_operators_yaml import (
|
||||
fill_output,
|
||||
get_parser_options,
|
||||
@ -241,5 +242,6 @@ class GenOperatorsYAMLTest(unittest.TestCase):
|
||||
|
||||
fill_output(output, options)
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for op_val in output["operators"].values():
|
||||
self.assertFalse(op_val["include_all_overloads"])
|
||||
|
@ -88,6 +88,7 @@ operators:
|
||||
self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
|
||||
|
||||
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||
False,
|
||||
False,
|
||||
@ -103,6 +104,7 @@ operators:
|
||||
)
|
||||
|
||||
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||
True,
|
||||
False,
|
||||
@ -118,6 +120,7 @@ operators:
|
||||
)
|
||||
|
||||
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||
False,
|
||||
True,
|
||||
|
@ -83,7 +83,9 @@ def _rank_correlated_tests(
|
||||
) -> list[str]:
|
||||
# Find the tests failures that are correlated with the edited files.
|
||||
# Filter the list to only include tests we want to run.
|
||||
# pyrefly: ignore # bad-assignment
|
||||
tests_to_run = set(tests_to_run)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
ratings = _get_ratings_for_tests(tests_to_run)
|
||||
prioritize = sorted(ratings, key=lambda x: -ratings[x])
|
||||
return prioritize
|
||||
|
@ -36,11 +36,13 @@ def concated_logs() -> str:
|
||||
for log_file in glob.glob(
|
||||
f"{REPO_ROOT}/test/test-reports/**/*.log", recursive=True
|
||||
):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
logs.append(f"=== {log_file} ===")
|
||||
with open(log_file) as f:
|
||||
# For every line, prefix with fake timestamp for log classifier
|
||||
for line in f:
|
||||
line = line.rstrip("\n") # Remove any trailing newline
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
logs.append(f"2020-01-01T00:00:00.0000000Z {line}")
|
||||
return "\n".join(logs)
|
||||
|
||||
|
@ -1739,6 +1739,7 @@ class KernelArgs:
|
||||
for outer, inner in chain(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
self.input_buffers.items(),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
self.output_buffers.items(),
|
||||
):
|
||||
if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
|
||||
|
@ -1480,6 +1480,7 @@ class CppGemmTemplate(CppTemplate):
|
||||
gemm_output_buffer = ir.Buffer(
|
||||
# pyrefly: ignore # missing-attribute
|
||||
name=gemm_output_name,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
layout=template_buffer.layout,
|
||||
)
|
||||
current_input_buffer = gemm_output_buffer
|
||||
@ -1503,6 +1504,7 @@ class CppGemmTemplate(CppTemplate):
|
||||
current_input_buffer = ir.Buffer(
|
||||
# pyrefly: ignore # missing-attribute
|
||||
name=buffer_name,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
layout=template_buffer.layout,
|
||||
)
|
||||
|
||||
|
@ -824,6 +824,7 @@ class CppWrapperGpu(CppWrapperCpu):
|
||||
call_args, arg_types = self.prepare_triton_wrapper_args(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
call_args,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
arg_types,
|
||||
)
|
||||
wrapper_name = f"call_{kernel_name}"
|
||||
|
@ -683,6 +683,7 @@ class MetalKernel(SIMDKernel):
|
||||
# pyrefly: ignore # missing-argument
|
||||
t
|
||||
for t in self.range_tree_nodes.values()
|
||||
# pyrefly: ignore # missing-argument
|
||||
if t.is_reduction
|
||||
)
|
||||
cmp_op = ">" if reduction_type == "argmax" else "<"
|
||||
@ -865,6 +866,7 @@ class MetalKernel(SIMDKernel):
|
||||
# pyrefly: ignore # missing-argument
|
||||
t.numel
|
||||
for t in self.range_trees
|
||||
# pyrefly: ignore # missing-argument
|
||||
if t.is_reduction
|
||||
)
|
||||
# If using dynamic shapes, set the threadgroup size to be the
|
||||
|
@ -968,6 +968,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||
# pyrefly: ignore # missing-argument
|
||||
t
|
||||
for t in self.range_trees
|
||||
# pyrefly: ignore # missing-argument
|
||||
if not t.is_reduction or self.inside_reduction
|
||||
]
|
||||
|
||||
|
@ -1004,6 +1004,7 @@ class FxConverter:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
call_kwargs[key]
|
||||
for key in signature
|
||||
# pyrefly: ignore # missing-attribute
|
||||
if key not in cfg.kwargs
|
||||
]
|
||||
|
||||
|
@ -421,6 +421,7 @@ def get_proxy_slot(
|
||||
else:
|
||||
# Attempt to build it from first principles.
|
||||
_build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
value = tracker.get(obj)
|
||||
|
||||
if value is None:
|
||||
|
Reference in New Issue
Block a user