mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Expand type checking to mypy strict files (#165697)
Expands Pyrefly type checking to check the files outlined in the mypy-strict.ini configuration file: Pull Request resolved: https://github.com/pytorch/pytorch/pull/165697 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
9095a9dfae
commit
f02e3947f6
16
pyrefly.toml
16
pyrefly.toml
@ -5,6 +5,7 @@ python-version = "3.12"
|
|||||||
project-includes = [
|
project-includes = [
|
||||||
"torch",
|
"torch",
|
||||||
"caffe2",
|
"caffe2",
|
||||||
|
"tools",
|
||||||
"test/test_bundled_images.py",
|
"test/test_bundled_images.py",
|
||||||
"test/test_bundled_inputs.py",
|
"test/test_bundled_inputs.py",
|
||||||
"test/test_complex.py",
|
"test/test_complex.py",
|
||||||
@ -24,8 +25,11 @@ project-excludes = [
|
|||||||
# ==== to test Pyrefly on a specific directory, simply comment it out ====
|
# ==== to test Pyrefly on a specific directory, simply comment it out ====
|
||||||
"torch/_inductor/runtime",
|
"torch/_inductor/runtime",
|
||||||
"torch/_inductor/codegen/triton.py",
|
"torch/_inductor/codegen/triton.py",
|
||||||
|
"tools/linter/adapters/test_device_bias_linter.py",
|
||||||
|
"tools/code_analyzer/gen_operators_yaml.py",
|
||||||
# formatting issues, will turn on after adjusting where suppressions can be
|
# formatting issues, will turn on after adjusting where suppressions can be
|
||||||
# in import statements
|
# in import statements
|
||||||
|
"tools/flight_recorder/components/types.py",
|
||||||
"torch/linalg/__init__.py",
|
"torch/linalg/__init__.py",
|
||||||
"torch/package/importer.py",
|
"torch/package/importer.py",
|
||||||
"torch/package/_package_pickler.py",
|
"torch/package/_package_pickler.py",
|
||||||
@ -40,17 +44,6 @@ project-excludes = [
|
|||||||
"torch/distributed/elastic/metrics/__init__.py",
|
"torch/distributed/elastic/metrics/__init__.py",
|
||||||
"torch/_inductor/fx_passes/bucketing.py",
|
"torch/_inductor/fx_passes/bucketing.py",
|
||||||
# ====
|
# ====
|
||||||
"benchmarks/instruction_counts/main.py",
|
|
||||||
"benchmarks/instruction_counts/definitions/setup.py",
|
|
||||||
"benchmarks/instruction_counts/applications/ci.py",
|
|
||||||
"benchmarks/instruction_counts/core/api.py",
|
|
||||||
"benchmarks/instruction_counts/core/expand.py",
|
|
||||||
"benchmarks/instruction_counts/core/types.py",
|
|
||||||
"benchmarks/instruction_counts/core/utils.py",
|
|
||||||
"benchmarks/instruction_counts/definitions/standard.py",
|
|
||||||
"benchmarks/instruction_counts/definitions/setup.py",
|
|
||||||
"benchmarks/instruction_counts/execution/runner.py",
|
|
||||||
"benchmarks/instruction_counts/execution/work.py",
|
|
||||||
"torch/include/**",
|
"torch/include/**",
|
||||||
"torch/csrc/**",
|
"torch/csrc/**",
|
||||||
"torch/distributed/elastic/agent/server/api.py",
|
"torch/distributed/elastic/agent/server/api.py",
|
||||||
@ -137,3 +130,4 @@ errors.bad-param-name-override = false
|
|||||||
errors.implicit-import = false
|
errors.implicit-import = false
|
||||||
permissive-ignores = true
|
permissive-ignores = true
|
||||||
replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
|
replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
|
||||||
|
search-path = ["tools/experimental"]
|
||||||
|
@ -863,6 +863,7 @@ static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
|||||||
saved_variables.append(f"{type.cpp_type()} {name};")
|
saved_variables.append(f"{type.cpp_type()} {name};")
|
||||||
|
|
||||||
if type in MISC_GETTER_DEFS:
|
if type in MISC_GETTER_DEFS:
|
||||||
|
# pyrefly: ignore # index-error
|
||||||
getter_def, body = MISC_GETTER_DEFS[type]
|
getter_def, body = MISC_GETTER_DEFS[type]
|
||||||
getter_definitions.append(
|
getter_definitions.append(
|
||||||
getter_def.substitute(op=info.op, name=name, body=body)
|
getter_def.substitute(op=info.op, name=name, body=body)
|
||||||
@ -1033,6 +1034,7 @@ static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
|
|||||||
unpack_ivalues = []
|
unpack_ivalues = []
|
||||||
for typ, name in zip(apply_functional_args_ref_types, apply_functional_args):
|
for typ, name in zip(apply_functional_args_ref_types, apply_functional_args):
|
||||||
typ = typ.removesuffix("&")
|
typ = typ.removesuffix("&")
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();")
|
unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();")
|
||||||
|
|
||||||
schema_args = [f"std::array<bool, {len(input_name_to_idx)}>"]
|
schema_args = [f"std::array<bool, {len(input_name_to_idx)}>"]
|
||||||
|
@ -182,6 +182,7 @@ def format_trace_inputs(f: NativeFunction) -> str:
|
|||||||
ADD_TRACE_INPUT.substitute(
|
ADD_TRACE_INPUT.substitute(
|
||||||
name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name
|
name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
for i in range(num_out_args)
|
for i in range(num_out_args)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1495,6 +1495,7 @@ def emit_body(
|
|||||||
else:
|
else:
|
||||||
expr = f"SavedVariable({var}, {str(is_output).lower()})"
|
expr = f"SavedVariable({var}, {str(is_output).lower()})"
|
||||||
if foreacharg is not None and "original_selfs" not in expr:
|
if foreacharg is not None and "original_selfs" not in expr:
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
expr = expr.replace(src_name, name_in_expr)
|
expr = expr.replace(src_name, name_in_expr)
|
||||||
elif (
|
elif (
|
||||||
type == BaseCType(tensorListT)
|
type == BaseCType(tensorListT)
|
||||||
@ -1844,12 +1845,14 @@ def emit_body(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
cur_derivative_conditions.append(
|
cur_derivative_conditions.append(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
|
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
|
||||||
req_inp=inp_name + "[i]"
|
req_inp=inp_name + "[i]"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cur_derivative_conditions.append(
|
cur_derivative_conditions.append(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp_name)
|
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1920,6 +1923,7 @@ def emit_body(
|
|||||||
unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
|
unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
|
||||||
inp_name="original_self",
|
inp_name="original_self",
|
||||||
inp="original_self" + input_suffix,
|
inp="original_self" + input_suffix,
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
zeros_fn=zeros_fn,
|
zeros_fn=zeros_fn,
|
||||||
)
|
)
|
||||||
unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
|
unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
|
||||||
|
@ -95,8 +95,11 @@ def add_view_copy_derivatives(
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
# prefer manually-defined derivatives if any
|
# prefer manually-defined derivatives if any
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
|
if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
assert fn_schema is not None
|
assert fn_schema is not None
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
view_infos[fn_schema] = view_copy_differentiability_infos
|
view_infos[fn_schema] = view_copy_differentiability_infos
|
||||||
|
|
||||||
infos.update(view_infos)
|
infos.update(view_infos)
|
||||||
@ -398,6 +401,7 @@ def postprocess_forward_derivatives(
|
|||||||
for arg_name in all_arg_names:
|
for arg_name in all_arg_names:
|
||||||
if arg_name in diff_arg_names:
|
if arg_name in diff_arg_names:
|
||||||
arg_name = arg_name + "_t"
|
arg_name = arg_name + "_t"
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
new_args.append(arg_name)
|
new_args.append(arg_name)
|
||||||
|
|
||||||
# TODO we are trolling
|
# TODO we are trolling
|
||||||
@ -938,6 +942,7 @@ def saved_variables(
|
|||||||
+ f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}"
|
+ f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}"
|
||||||
)
|
)
|
||||||
for nctype in nctypes:
|
for nctype in nctypes:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
name = (
|
name = (
|
||||||
nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
|
nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
|
||||||
)
|
)
|
||||||
@ -947,6 +952,7 @@ def saved_variables(
|
|||||||
|
|
||||||
def repl(m: re.Match[str]) -> str:
|
def repl(m: re.Match[str]) -> str:
|
||||||
suffix: str = (
|
suffix: str = (
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
|
info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
|
||||||
)
|
)
|
||||||
expr: str = info["expr"](name) if "expr" in info else m.group(0)
|
expr: str = info["expr"](name) if "expr" in info else m.group(0)
|
||||||
|
@ -67,6 +67,7 @@ def is_intrested_file(
|
|||||||
|
|
||||||
# ignore files that are not belong to pytorch
|
# ignore files that are not belong to pytorch
|
||||||
if platform == TestPlatform.OSS:
|
if platform == TestPlatform.OSS:
|
||||||
|
# pyrefly: ignore # import-error
|
||||||
from package.oss.utils import get_pytorch_folder
|
from package.oss.utils import get_pytorch_folder
|
||||||
|
|
||||||
if not file_path.startswith(get_pytorch_folder()):
|
if not file_path.startswith(get_pytorch_folder()):
|
||||||
|
@ -24,6 +24,7 @@ def report_download_progress(
|
|||||||
file_size: int,
|
file_size: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
if file_size != -1:
|
if file_size != -1:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
percent = min(1, (chunk_number * chunk_size) / file_size)
|
percent = min(1, (chunk_number * chunk_size) / file_size)
|
||||||
bar = "#" * int(64 * percent)
|
bar = "#" * int(64 * percent)
|
||||||
sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
|
sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
|
||||||
|
@ -105,8 +105,10 @@ def extract_info_from_keyword(source: str, kw: ast.keyword) -> Any:
|
|||||||
evaluated_context = []
|
evaluated_context = []
|
||||||
for value in kw.value.values:
|
for value in kw.value.values:
|
||||||
if isinstance(value, ast.FormattedValue):
|
if isinstance(value, ast.FormattedValue):
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
evaluated_context.append(f"{{{ast.unparse(value.value)}}}")
|
evaluated_context.append(f"{{{ast.unparse(value.value)}}}")
|
||||||
elif isinstance(value, ast.Constant):
|
elif isinstance(value, ast.Constant):
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
evaluated_context.append(value.value)
|
evaluated_context.append(value.value)
|
||||||
return "".join(evaluated_context)
|
return "".join(evaluated_context)
|
||||||
else:
|
else:
|
||||||
@ -152,6 +154,7 @@ def find_unimplemented_v2_calls(
|
|||||||
|
|
||||||
for kw in node.keywords:
|
for kw in node.keywords:
|
||||||
if kw.arg in info:
|
if kw.arg in info:
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
info[kw.arg] = extract_info_from_keyword(source, kw)
|
info[kw.arg] = extract_info_from_keyword(source, kw)
|
||||||
|
|
||||||
if info["gb_type"] is None:
|
if info["gb_type"] is None:
|
||||||
|
@ -296,6 +296,7 @@ def run_multi_process_fuzzer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def write_func(msg):
|
def write_func(msg):
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
pbar.write(msg)
|
pbar.write(msg)
|
||||||
else:
|
else:
|
||||||
persist_print("Progress: (install tqdm for better progress bar)")
|
persist_print("Progress: (install tqdm for better progress bar)")
|
||||||
|
@ -111,6 +111,7 @@ class ConstantOperator(Operator):
|
|||||||
]:
|
]:
|
||||||
# Clamp integer values to [0, 3] to avoid index overflow in multiplication
|
# Clamp integer values to [0, 3] to avoid index overflow in multiplication
|
||||||
# Even with multiplication, indices should stay in reasonable range
|
# Even with multiplication, indices should stay in reasonable range
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
fill_value = max(0, min(3, abs(fill_value)))
|
fill_value = max(0, min(3, abs(fill_value)))
|
||||||
|
|
||||||
tensor_creation = (
|
tensor_creation = (
|
||||||
|
@ -78,15 +78,22 @@ class JobConfig:
|
|||||||
def parse_args(
|
def parse_args(
|
||||||
self: "JobConfig", args: Optional[Sequence[str]]
|
self: "JobConfig", args: Optional[Sequence[str]]
|
||||||
) -> argparse.Namespace:
|
) -> argparse.Namespace:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
args = self.parser.parse_args(args)
|
args = self.parser.parse_args(args)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if args.selected_ranks is not None:
|
if args.selected_ranks is not None:
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
assert args.just_print_entries, (
|
assert args.just_print_entries, (
|
||||||
"Not support selecting ranks without printing entries"
|
"Not support selecting ranks without printing entries"
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if args.pg_filters is not None:
|
if args.pg_filters is not None:
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
assert args.just_print_entries, (
|
assert args.just_print_entries, (
|
||||||
"Not support selecting pg filters without printing entries"
|
"Not support selecting pg filters without printing entries"
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if args.verbose:
|
if args.verbose:
|
||||||
logger.set_log_level(logging.DEBUG)
|
logger.set_log_level(logging.DEBUG)
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return args
|
return args
|
||||||
|
@ -41,6 +41,7 @@ def format_frame(frame: dict[str, str]) -> str:
|
|||||||
def format_frames(frames: list[dict[str, str]]) -> str:
|
def format_frames(frames: list[dict[str, str]]) -> str:
|
||||||
formatted_frames = []
|
formatted_frames = []
|
||||||
for frame in frames:
|
for frame in frames:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
formatted_frames.append(format_frame(frame))
|
formatted_frames.append(format_frame(frame))
|
||||||
return "\n".join(formatted_frames)
|
return "\n".join(formatted_frames)
|
||||||
|
|
||||||
@ -695,6 +696,7 @@ def check_version(version_by_ranks: dict[str, str], version: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def get_version_detail(version: str) -> tuple[int, int]:
|
def get_version_detail(version: str) -> tuple[int, int]:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
version = version.split(".")
|
version = version.split(".")
|
||||||
assert len(version) == 2, f"Invalid version {version}"
|
assert len(version) == 2, f"Invalid version {version}"
|
||||||
major, minor = map(int, version)
|
major, minor = map(int, version)
|
||||||
|
@ -40,11 +40,17 @@ from tools.flight_recorder.components.types import types
|
|||||||
|
|
||||||
def main(args: Optional[Sequence[str]] = None) -> None:
|
def main(args: Optional[Sequence[str]] = None) -> None:
|
||||||
config = JobConfig()
|
config = JobConfig()
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
args = config.parse_args(args)
|
args = config.parse_args(args)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
assert args.trace_dir, "Trace directory trace_dir is required"
|
assert args.trace_dir, "Trace directory trace_dir is required"
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
details, version = read_dir(args)
|
details, version = read_dir(args)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
db = build_db(details, args, version)
|
db = build_db(details, args, version)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if args.output:
|
if args.output:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
with open(args.output, "wb") as f:
|
with open(args.output, "wb") as f:
|
||||||
pickle.dump((types, db), f)
|
pickle.dump((types, db), f)
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ class TensorRepr(gdb.Command): # type: ignore[misc, no-any-unimported]
|
|||||||
on it.
|
on it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
__doc__ = textwrap.dedent(__doc__).strip()
|
__doc__ = textwrap.dedent(__doc__).strip()
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -118,6 +118,7 @@ def extract_filename(path: str, keep_ext: bool = True) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
# https://gist.github.com/pypt/94d747fe5180851196eb
|
# https://gist.github.com/pypt/94d747fe5180851196eb
|
||||||
|
# pyrefly: ignore # invalid-inheritance
|
||||||
class UniqueKeyLoader(Loader):
|
class UniqueKeyLoader(Loader):
|
||||||
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
|
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
|
||||||
if not isinstance(node, MappingNode):
|
if not isinstance(node, MappingNode):
|
||||||
@ -233,6 +234,7 @@ def preprocess(
|
|||||||
last_indent = input_indent
|
last_indent = input_indent
|
||||||
|
|
||||||
while blank_lines != 0:
|
while blank_lines != 0:
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
python_lines.append(python_indent + "print(file=OUT_STREAM)")
|
python_lines.append(python_indent + "print(file=OUT_STREAM)")
|
||||||
blank_lines -= 1
|
blank_lines -= 1
|
||||||
|
|
||||||
@ -667,6 +669,7 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
|
|||||||
" ",
|
" ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
return shader_dispatch_str
|
return shader_dispatch_str
|
||||||
|
|
||||||
|
|
||||||
@ -681,15 +684,18 @@ def genCppFiles(
|
|||||||
name = getName(spvPath).replace("_spv", "")
|
name = getName(spvPath).replace("_spv", "")
|
||||||
|
|
||||||
sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
|
sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
spv_bin_strs.append(spv_bin_str)
|
spv_bin_strs.append(spv_bin_str)
|
||||||
|
|
||||||
shader_info = getShaderInfo(srcPath)
|
shader_info = getShaderInfo(srcPath)
|
||||||
|
|
||||||
register_shader_info_strs.append(
|
register_shader_info_strs.append(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
generateShaderInfoStr(shader_info, name, sizeBytes)
|
generateShaderInfoStr(shader_info, name, sizeBytes)
|
||||||
)
|
)
|
||||||
|
|
||||||
if shader_info.register_for is not None:
|
if shader_info.register_for is not None:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))
|
shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))
|
||||||
|
|
||||||
spv_bin_arrays = "\n".join(spv_bin_strs)
|
spv_bin_arrays = "\n".join(spv_bin_strs)
|
||||||
|
@ -131,12 +131,14 @@ class ComputeCodegenUnboxedKernels:
|
|||||||
else:
|
else:
|
||||||
arg_cpp = f"c10::IValue({arg_default})"
|
arg_cpp = f"c10::IValue({arg_default})"
|
||||||
args_code.append(
|
args_code.append(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
f"""c10::Argument("{arg.name}", nullptr, ::std::nullopt, {arg_cpp})"""
|
f"""c10::Argument("{arg.name}", nullptr, ::std::nullopt, {arg_cpp})"""
|
||||||
)
|
)
|
||||||
|
|
||||||
returns = f.func.returns
|
returns = f.func.returns
|
||||||
returns_code = []
|
returns_code = []
|
||||||
for ret in returns:
|
for ret in returns:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
returns_code.append(f"""c10::Argument("{ret.name if ret.name else ""}")""")
|
returns_code.append(f"""c10::Argument("{ret.name if ret.name else ""}")""")
|
||||||
return f"""
|
return f"""
|
||||||
// aten::{schema}
|
// aten::{schema}
|
||||||
|
@ -112,6 +112,7 @@ class FileLinter:
|
|||||||
first_results = None
|
first_results = None
|
||||||
original = replacement = pf.contents
|
original = replacement = pf.contents
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
results = sorted(self._lint(pf), key=LintResult.sort_key)
|
results = sorted(self._lint(pf), key=LintResult.sort_key)
|
||||||
|
@ -41,6 +41,7 @@ class LineWithSets:
|
|||||||
t = self.tokens[i]
|
t = self.tokens[i]
|
||||||
after = i < len(self.tokens) - 1 and self.tokens[i + 1]
|
after = i < len(self.tokens) - 1 and self.tokens[i + 1]
|
||||||
if t.string == "Set" and t.type == token.NAME:
|
if t.string == "Set" and t.type == token.NAME:
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return after and after.string == "[" and after.type == token.OP
|
return after and after.string == "[" and after.type == token.OP
|
||||||
return (
|
return (
|
||||||
(t.string == "set" and t.type == token.NAME)
|
(t.string == "set" and t.type == token.NAME)
|
||||||
|
@ -19,11 +19,13 @@ from typing import NamedTuple
|
|||||||
# PyTorch directory root
|
# PyTorch directory root
|
||||||
def scm_root() -> str:
|
def scm_root() -> str:
|
||||||
path = os.path.abspath(os.getcwd())
|
path = os.path.abspath(os.getcwd())
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
while True:
|
while True:
|
||||||
if os.path.exists(os.path.join(path, ".git")):
|
if os.path.exists(os.path.join(path, ".git")):
|
||||||
return path
|
return path
|
||||||
if os.path.isdir(os.path.join(path, ".hg")):
|
if os.path.isdir(os.path.join(path, ".hg")):
|
||||||
return path
|
return path
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
n = len(path)
|
n = len(path)
|
||||||
path = os.path.dirname(path)
|
path = os.path.dirname(path)
|
||||||
if len(path) == n:
|
if len(path) == n:
|
||||||
|
@ -101,6 +101,7 @@ def check_dictionary(filename: str) -> list[LintMessage]:
|
|||||||
words_set = set(words)
|
words_set = set(words)
|
||||||
if len(words) != len(words_set):
|
if len(words) != len(words_set):
|
||||||
raise ValueError("The dictionary file contains duplicate entries.")
|
raise ValueError("The dictionary file contains duplicate entries.")
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
uncased_words = list(map(str.lower, words))
|
uncased_words = list(map(str.lower, words))
|
||||||
if uncased_words != sorted(uncased_words):
|
if uncased_words != sorted(uncased_words):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -12,6 +12,7 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
# pyrefly: ignore # import-error
|
||||||
import isort
|
import isort
|
||||||
import usort
|
import usort
|
||||||
|
|
||||||
|
@ -55,6 +55,7 @@ def report_download_progress(
|
|||||||
Pretty printer for file download progress.
|
Pretty printer for file download progress.
|
||||||
"""
|
"""
|
||||||
if file_size != -1:
|
if file_size != -1:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
percent = min(1, (chunk_number * chunk_size) / file_size)
|
percent = min(1, (chunk_number * chunk_size) / file_size)
|
||||||
bar = "#" * int(64 * percent)
|
bar = "#" * int(64 * percent)
|
||||||
sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
|
sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
|
||||||
|
@ -15,7 +15,10 @@ import multiprocessing as mp
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
# pyrefly: ignore # import-error
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
|
|
||||||
|
# pyrefly: ignore # import-error
|
||||||
import libcst.matchers as m
|
import libcst.matchers as m
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,6 +69,7 @@ def print_lint_message(path: Path, job: dict[str, Any], sync_tag: str) -> None:
|
|||||||
|
|
||||||
lint_message = LintMessage(
|
lint_message = LintMessage(
|
||||||
path=str(path),
|
path=str(path),
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
line=line_number,
|
line=line_number,
|
||||||
char=None,
|
char=None,
|
||||||
code="WORKFLOWSYNC",
|
code="WORKFLOWSYNC",
|
||||||
|
@ -73,6 +73,7 @@ def get_selected_kernel_dtypes_code(
|
|||||||
for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
|
for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
|
||||||
conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
|
conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
|
||||||
body_parts.append(
|
body_parts.append(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if_condition_template.substitute(
|
if_condition_template.substitute(
|
||||||
kernel_tag_name=kernel_tag,
|
kernel_tag_name=kernel_tag,
|
||||||
dtype_checks=" || ".join(conditions),
|
dtype_checks=" || ".join(conditions),
|
||||||
|
@ -311,6 +311,7 @@ class Venv:
|
|||||||
python=python,
|
python=python,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
).stdout
|
).stdout
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
candidates = list(map(Path, filter(None, map(str.strip, output.splitlines()))))
|
candidates = list(map(Path, filter(None, map(str.strip, output.splitlines()))))
|
||||||
candidates = [p for p in candidates if p.is_dir() and p.name == "site-packages"]
|
candidates = [p for p in candidates if p.is_dir() and p.name == "site-packages"]
|
||||||
if not candidates:
|
if not candidates:
|
||||||
@ -480,6 +481,7 @@ class Venv:
|
|||||||
cmd = [str(python), *args]
|
cmd = [str(python), *args]
|
||||||
env = popen_kwargs.pop("env", None) or {}
|
env = popen_kwargs.pop("env", None) or {}
|
||||||
check = popen_kwargs.pop("check", True)
|
check = popen_kwargs.pop("check", True)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return subprocess.run(
|
return subprocess.run(
|
||||||
cmd,
|
cmd,
|
||||||
check=check,
|
check=check,
|
||||||
@ -531,6 +533,7 @@ class Venv:
|
|||||||
cmd = [str(self.bindir / "uv"), *args]
|
cmd = [str(self.bindir / "uv"), *args]
|
||||||
env = popen_kwargs.pop("env", None) or {}
|
env = popen_kwargs.pop("env", None) or {}
|
||||||
check = popen_kwargs.pop("check", True)
|
check = popen_kwargs.pop("check", True)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return subprocess.run(
|
return subprocess.run(
|
||||||
cmd,
|
cmd,
|
||||||
check=check,
|
check=check,
|
||||||
@ -938,6 +941,7 @@ def _move_single(
|
|||||||
|
|
||||||
def _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
|
def _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
|
||||||
for src in listing:
|
for src in listing:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
_move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
|
_move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
|
||||||
|
|
||||||
|
|
||||||
|
@ -118,6 +118,7 @@ def download_patch(pr_number: int, repo_url: str, download_dir: str) -> str:
|
|||||||
urllib.request.urlopen(patch_url) as response,
|
urllib.request.urlopen(patch_url) as response,
|
||||||
open(patch_file, "wb") as out_file,
|
open(patch_file, "wb") as out_file,
|
||||||
):
|
):
|
||||||
|
# pyrefly: ignore # bad-specialization
|
||||||
shutil.copyfileobj(response, out_file)
|
shutil.copyfileobj(response, out_file)
|
||||||
if not os.path.isfile(patch_file):
|
if not os.path.isfile(patch_file):
|
||||||
print(f"Failed to download patch for PR #{pr_number}")
|
print(f"Failed to download patch for PR #{pr_number}")
|
||||||
|
@ -994,6 +994,7 @@ def add_docstr_to_hint(docstr: str, hint: str) -> str:
|
|||||||
hint = hint.removesuffix("...").rstrip() # remove "..."
|
hint = hint.removesuffix("...").rstrip() # remove "..."
|
||||||
content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix=" ")
|
content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix=" ")
|
||||||
# Remove trailing whitespace on each line
|
# Remove trailing whitespace on each line
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
|
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
|
||||||
|
|
||||||
# attribute or property
|
# attribute or property
|
||||||
|
@ -100,6 +100,7 @@ class CMake:
|
|||||||
if ver is not None:
|
if ver is not None:
|
||||||
eprint(f"Found {cmd} ({command}) version: {ver}", end="")
|
eprint(f"Found {cmd} ({command}) version: {ver}", end="")
|
||||||
cmake_versions.append(f"{cmd}=={ver}")
|
cmake_versions.append(f"{cmd}=={ver}")
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
if ver >= CMAKE_MINIMUM_VERSION:
|
if ver >= CMAKE_MINIMUM_VERSION:
|
||||||
eprint(f" (>={CMAKE_MINIMUM_VERSION})")
|
eprint(f" (>={CMAKE_MINIMUM_VERSION})")
|
||||||
valid_cmake_versions[cmd] = ver
|
valid_cmake_versions[cmd] = ver
|
||||||
|
@ -31,7 +31,9 @@ def gen_linker_script(
|
|||||||
text_line_start = text_line_start[0]
|
text_line_start = text_line_start[0]
|
||||||
|
|
||||||
# ensure that parent directory exists before writing
|
# ensure that parent directory exists before writing
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
fout = Path(fout)
|
fout = Path(fout)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
fout.parent.mkdir(parents=True, exist_ok=True)
|
fout.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with open(fout, "w") as f:
|
with open(fout, "w") as f:
|
||||||
|
@ -60,6 +60,7 @@ class SegmentGenerator:
|
|||||||
df[time_col_name] = pd.to_datetime(df[time_col_name], unit="s", utc=True)
|
df[time_col_name] = pd.to_datetime(df[time_col_name], unit="s", utc=True)
|
||||||
|
|
||||||
# get unique cmd names
|
# get unique cmd names
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
unique_cmds_df = pd.DataFrame(df[cmd_col_name].unique(), columns=[cmd_col_name])
|
unique_cmds_df = pd.DataFrame(df[cmd_col_name].unique(), columns=[cmd_col_name])
|
||||||
|
|
||||||
# get all detected python cmds
|
# get all detected python cmds
|
||||||
|
@ -7,6 +7,7 @@ import unittest
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
# pyrefly: ignore # import-error
|
||||||
from gen_operators_yaml import (
|
from gen_operators_yaml import (
|
||||||
fill_output,
|
fill_output,
|
||||||
get_parser_options,
|
get_parser_options,
|
||||||
@ -241,5 +242,6 @@ class GenOperatorsYAMLTest(unittest.TestCase):
|
|||||||
|
|
||||||
fill_output(output, options)
|
fill_output(output, options)
|
||||||
|
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
for op_val in output["operators"].values():
|
for op_val in output["operators"].values():
|
||||||
self.assertFalse(op_val["include_all_overloads"])
|
self.assertFalse(op_val["include_all_overloads"])
|
||||||
|
@ -88,6 +88,7 @@ operators:
|
|||||||
self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
|
self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
|
||||||
|
|
||||||
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
["aten::add", "aten::add.int", "aten::mul.int"],
|
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
@ -103,6 +104,7 @@ operators:
|
|||||||
)
|
)
|
||||||
|
|
||||||
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
["aten::add", "aten::add.int", "aten::mul.int"],
|
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||||
True,
|
True,
|
||||||
False,
|
False,
|
||||||
@ -118,6 +120,7 @@ operators:
|
|||||||
)
|
)
|
||||||
|
|
||||||
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
["aten::add", "aten::add.int", "aten::mul.int"],
|
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||||
False,
|
False,
|
||||||
True,
|
True,
|
||||||
|
@ -83,7 +83,9 @@ def _rank_correlated_tests(
|
|||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
# Find the tests failures that are correlated with the edited files.
|
# Find the tests failures that are correlated with the edited files.
|
||||||
# Filter the list to only include tests we want to run.
|
# Filter the list to only include tests we want to run.
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
tests_to_run = set(tests_to_run)
|
tests_to_run = set(tests_to_run)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
ratings = _get_ratings_for_tests(tests_to_run)
|
ratings = _get_ratings_for_tests(tests_to_run)
|
||||||
prioritize = sorted(ratings, key=lambda x: -ratings[x])
|
prioritize = sorted(ratings, key=lambda x: -ratings[x])
|
||||||
return prioritize
|
return prioritize
|
||||||
|
@ -36,11 +36,13 @@ def concated_logs() -> str:
|
|||||||
for log_file in glob.glob(
|
for log_file in glob.glob(
|
||||||
f"{REPO_ROOT}/test/test-reports/**/*.log", recursive=True
|
f"{REPO_ROOT}/test/test-reports/**/*.log", recursive=True
|
||||||
):
|
):
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
logs.append(f"=== {log_file} ===")
|
logs.append(f"=== {log_file} ===")
|
||||||
with open(log_file) as f:
|
with open(log_file) as f:
|
||||||
# For every line, prefix with fake timestamp for log classifier
|
# For every line, prefix with fake timestamp for log classifier
|
||||||
for line in f:
|
for line in f:
|
||||||
line = line.rstrip("\n") # Remove any trailing newline
|
line = line.rstrip("\n") # Remove any trailing newline
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
logs.append(f"2020-01-01T00:00:00.0000000Z {line}")
|
logs.append(f"2020-01-01T00:00:00.0000000Z {line}")
|
||||||
return "\n".join(logs)
|
return "\n".join(logs)
|
||||||
|
|
||||||
|
@ -1739,6 +1739,7 @@ class KernelArgs:
|
|||||||
for outer, inner in chain(
|
for outer, inner in chain(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.input_buffers.items(),
|
self.input_buffers.items(),
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.output_buffers.items(),
|
self.output_buffers.items(),
|
||||||
):
|
):
|
||||||
if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
|
if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
|
||||||
|
@ -1480,6 +1480,7 @@ class CppGemmTemplate(CppTemplate):
|
|||||||
gemm_output_buffer = ir.Buffer(
|
gemm_output_buffer = ir.Buffer(
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
name=gemm_output_name,
|
name=gemm_output_name,
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
layout=template_buffer.layout,
|
layout=template_buffer.layout,
|
||||||
)
|
)
|
||||||
current_input_buffer = gemm_output_buffer
|
current_input_buffer = gemm_output_buffer
|
||||||
@ -1503,6 +1504,7 @@ class CppGemmTemplate(CppTemplate):
|
|||||||
current_input_buffer = ir.Buffer(
|
current_input_buffer = ir.Buffer(
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
name=buffer_name,
|
name=buffer_name,
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
layout=template_buffer.layout,
|
layout=template_buffer.layout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -824,6 +824,7 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||||||
call_args, arg_types = self.prepare_triton_wrapper_args(
|
call_args, arg_types = self.prepare_triton_wrapper_args(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore # bad-argument-type
|
||||||
call_args,
|
call_args,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
arg_types,
|
arg_types,
|
||||||
)
|
)
|
||||||
wrapper_name = f"call_{kernel_name}"
|
wrapper_name = f"call_{kernel_name}"
|
||||||
|
@ -683,6 +683,7 @@ class MetalKernel(SIMDKernel):
|
|||||||
# pyrefly: ignore # missing-argument
|
# pyrefly: ignore # missing-argument
|
||||||
t
|
t
|
||||||
for t in self.range_tree_nodes.values()
|
for t in self.range_tree_nodes.values()
|
||||||
|
# pyrefly: ignore # missing-argument
|
||||||
if t.is_reduction
|
if t.is_reduction
|
||||||
)
|
)
|
||||||
cmp_op = ">" if reduction_type == "argmax" else "<"
|
cmp_op = ">" if reduction_type == "argmax" else "<"
|
||||||
@ -865,6 +866,7 @@ class MetalKernel(SIMDKernel):
|
|||||||
# pyrefly: ignore # missing-argument
|
# pyrefly: ignore # missing-argument
|
||||||
t.numel
|
t.numel
|
||||||
for t in self.range_trees
|
for t in self.range_trees
|
||||||
|
# pyrefly: ignore # missing-argument
|
||||||
if t.is_reduction
|
if t.is_reduction
|
||||||
)
|
)
|
||||||
# If using dynamic shapes, set the threadgroup size to be the
|
# If using dynamic shapes, set the threadgroup size to be the
|
||||||
|
@ -968,6 +968,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
|||||||
# pyrefly: ignore # missing-argument
|
# pyrefly: ignore # missing-argument
|
||||||
t
|
t
|
||||||
for t in self.range_trees
|
for t in self.range_trees
|
||||||
|
# pyrefly: ignore # missing-argument
|
||||||
if not t.is_reduction or self.inside_reduction
|
if not t.is_reduction or self.inside_reduction
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1004,6 +1004,7 @@ class FxConverter:
|
|||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
call_kwargs[key]
|
call_kwargs[key]
|
||||||
for key in signature
|
for key in signature
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if key not in cfg.kwargs
|
if key not in cfg.kwargs
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -421,6 +421,7 @@ def get_proxy_slot(
|
|||||||
else:
|
else:
|
||||||
# Attempt to build it from first principles.
|
# Attempt to build it from first principles.
|
||||||
_build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
|
_build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
value = tracker.get(obj)
|
value = tracker.get(obj)
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
|
Reference in New Issue
Block a user