Pyrefly suppressions 4/n (#164615)

Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions and clean up unused suppressions (a minimal sketch of the suppression format follows below)
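
For reference, each suppression added here is a one-line comment of the form "# pyrefly: ignore # <error-kind>", placed either directly above the flagged line or at the end of it; the text after the second "#" records the error kind(s) being silenced. Below is a minimal, self-contained sketch mirroring the autograd-Function pattern that recurs throughout this diff (it follows the MyAutogradFunction hunk shown later; the example itself is illustrative, not a new change in this PR):

import torch

class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, x):
        # the comment above silences the bad-override diagnostic pyrefly
        # reports for this forward() signature
        return x.clone()

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        return grad_output + 1

Unused suppressions, i.e. ignore comments that no longer correspond to a reported error, are what the cleanup in step 3 removes.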
before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1

after:

0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
Authored by Maggie Moss on 2025-10-06 16:14:36 +00:00; committed by PyTorch MergeBot
parent 4bd1505f84
commit 4ab847bbc7
52 changed files with 293 additions and 21 deletions

View File

@ -29,15 +29,11 @@ project-excludes = [
"torch/fx/**",
"torch/distributions/**",
"torch/onnx/**",
"torch/_refs/**",
"torch/_export/**",
"torch/jit/**",
"torch/optim/**",
"torch/_higher_order_ops/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",
"torch/jit/annotations.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",

View File

@ -1097,6 +1097,7 @@ class TS2FXGraphConverter:
# Update the value of loop local variables.
if node.outputsSize() >= 1:
# pyrefly: ignore # bad-assignment
for i, outp in enumerate(node.outputs()):
output_name = outp.debugName()
self.name_to_node[output_name] = self.fx_graph.call_function(
@ -1109,6 +1110,7 @@ class TS2FXGraphConverter:
fx_block_args[i] = self.name_to_node[output_name]
# Update the value of global variables, whose values are modified inplace.
# pyrefly: ignore # bad-assignment
for i, name in enumerate(
subgraph_converter.name_update_from_subblock_to_parent
):

View File

@ -132,6 +132,7 @@ def _make_export_case(m, name, configs):
m.__doc__ is not None
), f"Could not find description or docstring for export case: {m}"
configs = {**configs, "description": m.__doc__}
# pyrefly: ignore # bad-argument-type
return ExportCase(**{**configs, "model": m, "name": name})

View File

@ -3,10 +3,12 @@ import torch
class MyAutogradFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, x):
return x.clone()
@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
return grad_output + 1

View File

@ -39,6 +39,7 @@ def get_class_if_classified_error(e: Exception) -> Optional[str]:
TorchRuntimeError: None,
}
if type(e) in _ALLOW_LIST:
# pyrefly: ignore # index-error
attr_name = _ALLOW_LIST[type(e)]
if attr_name is None:
return ALWAYS_CLASSIFIED

View File

@ -101,6 +101,7 @@ class _KeyPathTrie:
assert len(kp) > 0
k, *kp = kp # type: ignore[assignment]
node = node[k]
# pyrefly: ignore # bad-return
return node, kp
@ -139,6 +140,7 @@ def key_path_to_source(
source: Source = LocalSource("args")
else:
source, kp = sourced_prefixes.get(kp)
# pyrefly: ignore # bad-assignment
for k in kp:
if isinstance(k, SequenceKey):
source = GetItemSource(source, k.idx)
@ -354,10 +356,12 @@ def _override_builtin_ops():
original_min = builtins.min
original_pow = math.pow
# pyrefly: ignore # bad-assignment
builtins.max = functools.partial(
_tensor_min_max, real_callable=original_max, tensor_callable=torch.maximum
)
# pyrefly: ignore # bad-assignment
builtins.min = functools.partial(
_tensor_min_max, real_callable=original_min, tensor_callable=torch.minimum
)
@ -1083,6 +1087,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
def run():
# Run sequence.
# pyrefly: ignore # index-error
t = args[0]
for _method, _args in sequence:
t = _method(t, *_args)

View File

@ -188,6 +188,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self.callback = callback
self.node: torch.fx.Node = next(iter(gm.graph.nodes))
# pyrefly: ignore # bad-override
def placeholder(
self,
target: str, # type: ignore[override]
@ -439,6 +440,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
)
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
interpreter = self.ExportInterpreter(self, graph_module)
# pyrefly: ignore # bad-assignment
prev_interpreter, self.interpreter = (
self.interpreter,
torch.fx.Interpreter( # type: ignore[assignment]

View File

@ -32,6 +32,7 @@ def _node_metadata_hook(
that nodes being added are only call_function nodes, and copies over the
first argument node's nn_module_stack.
"""
# pyrefly: ignore # bad-assignment
fake_mode = fake_mode or contextlib.nullcontext()
assert node.op == "call_function" and callable(node.target), (
@ -47,6 +48,7 @@ def _node_metadata_hook(
fake_args, fake_kwargs = pytree.tree_map_only(
torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs)
)
# pyrefly: ignore # bad-context-manager
with fake_mode, enable_python_dispatcher():
fake_res = node.target(*fake_args, **fake_kwargs)
node.meta["val"] = fake_res
@ -81,7 +83,9 @@ def _node_metadata_hook(
node.meta["torch_fn"] = node.meta.get(
"torch_fn",
(
# pyrefly: ignore # missing-attribute
f"{node.target.__name__}_0",
# pyrefly: ignore # missing-attribute
f"{node.target.__class__.__name__}.{node.target.__name__}",
),
)

View File

@ -567,6 +567,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
quantized = False
last_quantized_node = None
# pyrefly: ignore # bad-assignment
for node in gm.graph.nodes:
if isinstance(node.target, OpOverload):
with gm.graph.inserting_before(node):
@ -629,6 +630,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
attr_names_to_clean.add(k)
if k == "_buffers":
buffer_name_to_clean = set()
# pyrefly: ignore # missing-attribute
for b_name, b_value in v.items():
if isinstance(b_value, torch.Tensor) and b_value.dtype in [
torch.qint8,
@ -636,6 +638,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
]:
buffer_name_to_clean.add(b_name)
for b_name in buffer_name_to_clean:
# pyrefly: ignore # missing-attribute
v.pop(b_name, None)
for attr_name in attr_names_to_clean:
delattr(submod, attr_name)

View File

@ -35,6 +35,7 @@ def _replace_with_hop_helper(
)
call_func_node.meta["torch_fn"] = (
f"{wrap_hoo.__name__}",
# pyrefly: ignore # missing-attribute
f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
)
if isinstance(output_args, (tuple, list)):

View File

@ -54,6 +54,7 @@ def _postprocess_serialized_shapes(
)
for k, v in sorted(dims.items())
}
# pyrefly: ignore # bad-argument-type
spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
if to_dict:
return _dataclass_to_dict(spec)
@ -183,6 +184,7 @@ def _dump_dynamic_shapes(
kwargs = kwargs or {}
if isinstance(dynamic_shapes, dict):
dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment]
# pyrefly: ignore # bad-assignment, bad-argument-type
dynamic_shapes = tuple(dynamic_shapes)
combined_args = tuple(args) + tuple(kwargs.values())

View File

@ -623,7 +623,9 @@ class _Commit:
def update_schema():
import importlib.resources
# pyrefly: ignore # bad-argument-type
if importlib.resources.is_resource(__package__, "schema.yaml"):
# pyrefly: ignore # bad-argument-type
content = importlib.resources.read_text(__package__, "schema.yaml")
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
_check(match is not None, "checksum not found in schema.yaml")
@ -631,7 +633,9 @@ def update_schema():
checksum_head = match.group(1)
thrift_content = importlib.resources.read_text(
__package__, "export_schema.thrift"
# pyrefly: ignore # bad-argument-type
__package__,
"export_schema.thrift",
)
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content)
_check(match is not None, "checksum not found in export_schema.thrift")
@ -654,7 +658,9 @@ def update_schema():
src, cpp_header, thrift_schema = _staged_schema()
additions, subtractions = _diff_schema(dst, src)
# pyrefly: ignore # missing-attribute
yaml_path = __package__.replace(".", "/") + "/schema.yaml"
# pyrefly: ignore # missing-attribute
thrift_schema_path = __package__.replace(".", "/") + "/export_schema.thrift"
torch_prefix = "torch/"
assert yaml_path.startswith(torch_prefix) # sanity check

View File

@ -383,6 +383,7 @@ def _reconstruct_fake_tensor(
fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta)
if is_parameter:
fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment]
# pyrefly: ignore # bad-return
return fake_tensor
@ -2740,6 +2741,7 @@ class GraphModuleDeserializer(metaclass=Final):
serialized_node.metadata
)
assert arg is not None
# pyrefly: ignore # bad-argument-type
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
fx_node.meta["val"] = tuple(meta_val)
self.serialized_name_to_node[fx_node.name] = fx_node
@ -3165,6 +3167,7 @@ def _dict_to_dataclass(cls, data):
_value = next(iter(data.values()))
assert isinstance(_type, str)
field_type = cls.__annotations__[_type]
# pyrefly: ignore # missing-attribute
return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
elif dataclasses.is_dataclass(cls):
fields = {}
@ -3471,18 +3474,23 @@ def _canonicalize_graph(
n.metadata.clear()
# Stage 4: Aggregate values.
# pyrefly: ignore # no-matching-overload
sorted_tensor_values = dict(
sorted(graph.tensor_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
sorted_sym_int_values = dict(
sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
sorted_sym_float_values = dict(
sorted(graph.sym_float_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
sorted_sym_bool_values = dict(
sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
sorted_custom_obj_values = dict(
sorted(graph.custom_obj_values.items(), key=operator.itemgetter(0))
)
@ -3539,6 +3547,7 @@ def canonicalize(
ExportedProgram: The canonicalized exported program.
"""
ep = copy.deepcopy(ep)
# pyrefly: ignore # annotation-mismatch
constants: set[str] = constants or set()
opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0)))

View File

@ -34,7 +34,7 @@ from torch.fx._pytree import (
_deregister_pytree_flatten_spec,
register_pytree_flatten_spec,
)
from torch.utils._pytree import (
from torch.utils._pytree import ( # pyrefly: ignore # deprecated
_deregister_pytree_node,
_register_pytree_node,
Context,
@ -470,7 +470,14 @@ def _check_input_constraints_for_graph(
)
elif isinstance(node_val, torch.SymInt):
_check_symint(
node_val, arg, range_constraints, unification_map, key_path, None
# pyrefly: ignore # bad-argument-type
node_val,
# pyrefly: ignore # bad-argument-type
arg,
range_constraints,
unification_map,
key_path,
None,
)
@ -1115,12 +1122,14 @@ def placeholder_naming_pass(
if ( # handle targets for custom objects
spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map
):
# pyrefly: ignore # index-error
spec.target = name_map[spec.target][4:] # strip obj_ prefix
for spec in export_graph_signature.output_specs:
if spec.arg.name in name_map:
spec.arg.name = name_map[spec.arg.name]
if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map:
# pyrefly: ignore # index-error
spec.target = name_map[spec.target]
# rename keys in constants dict for custom objects

View File

@ -96,6 +96,7 @@ class AssociativeScanOp(HigherOrderOperator):
validate_subgraph_args_types(additional_inputs)
return super().__call__(combine_fn, xs, additional_inputs)
# pyrefly: ignore # bad-override
def gen_schema(self, combine_fn, xs, additional_inputs):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import materialize_as_graph
@ -648,6 +649,7 @@ class AssociativeScanAutogradOp(torch.autograd.Function):
"""
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx,
combine_fn,

View File

@ -609,6 +609,7 @@ def do_auto_functionalize_v2(
normalized_kwargs = {}
schema = op._schema
# pyrefly: ignore # bad-assignment
op = op._op if isinstance(op, HopInstance) else op
assert isinstance(op, get_args(_MutableOpType))

View File

@ -170,6 +170,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
out = self(functionalized_subgraph, *unwrapped_operands, **kwargs)
return ctx.wrap_tensors(out)
# pyrefly: ignore # bad-override
def gen_schema(self, subgraph, *operands, **kwargs):
from .schema import HopSchemaGenerator
@ -214,6 +215,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
class BaseHOPFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, hop, subgraph, kwargs, *operands):
ctx.hop = hop
ctx.operands = operands

View File

@ -52,6 +52,7 @@ class CondOp(HigherOrderOperator):
validate_subgraph_args_types(operands)
return super().__call__(pred, true_fn, false_fn, operands)
# pyrefly: ignore # bad-override
def gen_schema(self, pred, true_fn, false_fn, operands):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import materialize_as_graph
@ -284,6 +285,7 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
class CondAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx,
pred,

View File

@ -298,4 +298,5 @@ def handle_effects(
assert isinstance(wrapped_token, torch.Tensor)
tokens[key] = wrapped_token
# pyrefly: ignore # bad-argument-type
return ctx.wrap_tensors(unwrapped_outs)

View File

@ -354,12 +354,18 @@ def trace_flex_attention(
score_mod_other_buffers,
mask_mod_other_buffers,
)
# pyrefly: ignore # missing-attribute
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", flex_attention, proxy_args, {}
)
return track_tensor_tree(
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
# pyrefly: ignore # bad-argument-type
example_out,
out_proxy,
constant=None,
# pyrefly: ignore # bad-argument-type
tracer=proxy_mode.tracer,
)
@ -621,6 +627,7 @@ def create_fw_bw_graph(
class FlexAttentionAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx: Any,
query: Tensor,
@ -1063,6 +1070,7 @@ def trace_flex_attention_backward(
score_mod_other_buffers,
mask_mod_other_buffers,
)
# pyrefly: ignore # missing-attribute
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function",
@ -1072,7 +1080,12 @@ def trace_flex_attention_backward(
name="flex_attention_backward",
)
return track_tensor_tree(
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
# pyrefly: ignore # bad-argument-type
example_out,
out_proxy,
constant=None,
# pyrefly: ignore # bad-argument-type
tracer=proxy_mode.tracer,
)

View File

@ -86,6 +86,7 @@ class InvokeSubgraphHOP(HigherOrderOperator):
return super().__call__(subgraph, identifier, *operands)
# pyrefly: ignore # bad-override
def gen_schema(self, subgraph, identifier, *operands):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import (
@ -401,6 +402,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
"""
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx,
subgraph,
@ -477,6 +479,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
for tangent in filtered_grad_outs:
metadata = extract_tensor_metadata(tangent)
metadata._flatten_into(tangent_metadata, fake_mode, state)
# pyrefly: ignore # bad-assignment
tangent_metadata = tuple(tangent_metadata)
# bw_graph is a joint graph with signature (*primals_and_tangents) and

View File

@ -197,6 +197,7 @@ def create_hop_fw_bw(
class LocalMapAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx: Any,
fw_gm: GraphModule,
@ -243,6 +244,7 @@ class LocalMapAutogradOp(torch.autograd.Function):
)
for i, meta in ctx.expected_tangent_metadata.items():
# pyrefly: ignore # bad-argument-type
grads[i] = coerce_to_expected_memory_format(grads[i], meta)
grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads)

View File

@ -125,6 +125,7 @@ def map(
class MapAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, f, num_mapped_args, *flat_args):
ctx._f = f
ctx._num_mapped_args = num_mapped_args

View File

@ -241,6 +241,7 @@ class ScanOp(HigherOrderOperator):
validate_subgraph_args_types(additional_inputs)
return super().__call__(combine_fn, init, xs, additional_inputs)
# pyrefly: ignore # bad-override
def gen_schema(self, combine_fn, init, xs, additional_inputs):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import materialize_as_graph
@ -448,6 +449,7 @@ class ScanAutogradOp(torch.autograd.Function):
"""
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx,
hop_partitioned_graph,

View File

@ -292,6 +292,7 @@ def generate_ttir(
ordered_args[name] = 2
elif (
stable_meta := maybe_unpack_tma_stable_metadata(
# pyrefly: ignore # bad-argument-type
tma_descriptor_metadata.get(name, None)
)
) is not None:
@ -425,6 +426,7 @@ def generate_ttir(
specialize_value=not kp.do_not_specialize,
align=not kp.do_not_specialize_on_alignment,
)
# pyrefly: ignore # unsupported-operation
attrvals.append(spec[1])
attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
@ -443,6 +445,7 @@ def generate_ttir(
def get_signature_value(idx: int, arg: Any) -> str:
if kernel.params[idx].is_constexpr:
return "constexpr"
# pyrefly: ignore # not-callable
return mangle_type(arg)
else:
@ -815,6 +818,7 @@ def get_tma_stores(
for op in op_list:
if op.name == "tt.call":
assert op.fn_call_name in functions
# pyrefly: ignore # bad-argument-type
tma_stores = get_tma_stores(functions, op.fn_call_name)
for i, inp in enumerate(op.args):
if Param(idx=i) in tma_stores:
@ -895,7 +899,11 @@ def analyze_kernel_mutations(
if op.name == "tt.call":
assert op.fn_call_name in functions
mutations = analyze_kernel_mutations(
functions, op.fn_call_name, len(op.args)
# pyrefly: ignore # bad-argument-type
functions,
# pyrefly: ignore # bad-argument-type
op.fn_call_name,
len(op.args),
)
stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
else:
@ -948,6 +956,7 @@ def identify_mutated_tensors(
assert functions is not None
kernel_name = next(iter(functions.keys()))
# Triton codegen modifies the name
# pyrefly: ignore # missing-attribute
assert kernel.fn.__name__ in kernel_name
# Reset the cache between top level invocations
# The cache for analyze kernel mutations is mainly used for cycle
@ -1051,7 +1060,11 @@ def triton_kernel_wrapper_mutation_dense(
grid_fn = grid[0]
else:
fn_name, code = user_defined_kernel_grid_fn_code(
kernel.fn.__name__, kernel.configs, grid
# pyrefly: ignore # missing-attribute
kernel.fn.__name__,
# pyrefly: ignore # missing-attribute
kernel.configs,
grid,
)
namespace: dict[str, Any] = {}
exec(code, namespace)
@ -1100,6 +1113,7 @@ def triton_kernel_wrapper_mutation_dense(
# avoid mutating the original inputs
kwargs = kwargs.copy()
constant_args = constant_args.copy()
# pyrefly: ignore # missing-attribute
for name in kernel.arg_names:
if name in kwargs:
args.append(kwargs.pop(name))
@ -1108,6 +1122,7 @@ def triton_kernel_wrapper_mutation_dense(
else:
break
# pyrefly: ignore # index-error
kernel[grid_fn](*args, **kwargs, **constant_args)
@ -1513,6 +1528,7 @@ class TritonHOPifier:
assert kernel_idx is None or variable.kernel_idx == kernel_idx
# pyrefly: ignore # bad-assignment
variable.grid = grid
if isinstance(kernel, Autotuner):
@ -2057,6 +2073,7 @@ class TraceableTritonKernelWrapper:
return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
else:
assert self.kernel is not None
# pyrefly: ignore # missing-attribute
return self.kernel.run(*args, **kwargs)
def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
@ -2068,6 +2085,7 @@ class TraceableTritonKernelWrapper:
)
else:
assert self.kernel is not None
# pyrefly: ignore # index-error
return self.kernel[self.grid](*args, **kwargs)
def specialize_symbolic(self, arg: Sequence[Any]) -> Any:

View File

@ -270,6 +270,7 @@ def _set_compilation_env():
# We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo
# once we are confident fx tracing works with dynamo.
torch.fx._symbolic_trace._is_fx_tracing_flag = False
# pyrefly: ignore # bad-assignment
torch._dynamo.config.allow_empty_graphs = True
torch._dynamo.config.capture_scalar_outputs = True
yield
@ -440,6 +441,7 @@ def unique_graph_name_with_root(
) -> tuple[int, str]:
next_name = None
i = 0
# pyrefly: ignore # bad-assignment
while not next_name:
candidate = f"{prefix}_{i}"
if hasattr(root, candidate):
@ -795,6 +797,8 @@ def create_bw_fn(
"""
from torch._functorch.aot_autograd import AOTConfig, create_joint
# pyrefly: ignore # missing-module-attribute
from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad
dummy_aot_config = AOTConfig(
@ -939,6 +943,7 @@ def check_input_alias_and_mutation(
out_out_alias_map,
mutated_inputs,
) = check_input_alias_and_mutation_return_outputs(gm)[:-1]
# pyrefly: ignore # bad-return
return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs

View File

@ -54,6 +54,7 @@ class WhileLoopOp(HigherOrderOperator):
validate_subgraph_args_types(additional_inputs)
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
# pyrefly: ignore # bad-override
def gen_schema(self, cond_fn, body_fn, carried_inputs, additional_inputs):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import materialize_as_graph
@ -430,6 +431,7 @@ def while_loop_tracing(
elif isinstance(x, torch.Tensor):
x = x.clone()
if hasattr(x, "constant") and x.constant is not None:
# pyrefly: ignore # missing-attribute
x.constant = None
return x
@ -452,6 +454,7 @@ def while_loop_tracing(
next_name = None
i = 0
# pyrefly: ignore # bad-assignment
while not next_name:
candidate = f"while_loop_cond_graph_{i}"
if hasattr(proxy_mode.tracer.root, candidate):
@ -696,6 +699,7 @@ class WhileLoopStackOutputOp(HigherOrderOperator):
class WhileLoopAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx,
cond_fn,
@ -725,6 +729,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
ctx.additional_inputs = additional_inputs
ctx.fw_outputs = fw_outputs
loop_count = None
# pyrefly: ignore # bad-assignment
for out in fw_outputs:
if isinstance(out, torch.Tensor):
if loop_count is not None:
@ -878,6 +883,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
while_loop_op(
cond_gm,
body_gm,
# pyrefly: ignore # bad-argument-type
(
init_idx,
*init_grad_carries,

View File

@ -880,10 +880,14 @@ def logsumexp(
if not isinstance(dim, Iterable):
dim = (dim,)
if self.numel() == 0:
# pyrefly: ignore # no-matching-overload
return torch.sum(torch.exp(self), dim, keepdim).log()
# pyrefly: ignore # bad-argument-type
maxes = torch.amax(torch.real(self), dim, keepdim=True)
maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
# pyrefly: ignore # no-matching-overload
maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
# pyrefly: ignore # no-matching-overload
result = torch.sum(torch.exp(self - maxes), dim, keepdim)
return result.log().add(maxes_squeezed)
@ -1241,10 +1245,12 @@ def copysign(
a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
if isinstance(b, Number) and isinstance(a, Tensor):
# pyrefly: ignore # bad-argument-type
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
raise RuntimeError(msg)
# pyrefly: ignore # bad-argument-type
return where(signbit(b), neg(abs(a)), abs(a))
@ -1330,10 +1336,13 @@ def float_power(
# Float power has the following contiguous cast behavior to be
# consistent with its C++ impl
# pyrefly: ignore # no-matching-overload
a = _maybe_convert_to_dtype(a, dtype)
# pyrefly: ignore # no-matching-overload
b = _maybe_convert_to_dtype(b, dtype)
a, b = _maybe_broadcast(a, b)
# pyrefly: ignore # bad-return
return pow(a, b)
@ -1375,11 +1384,15 @@ def floor_divide(
):
# Wrap scalars because some references only accept tensor arguments.
if isinstance(a, Number) and isinstance(b, Number):
# pyrefly: ignore # bad-argument-type
a = scalar_tensor(a)
# pyrefly: ignore # bad-argument-type
b = scalar_tensor(b)
elif isinstance(b, Number) and isinstance(a, Tensor):
# pyrefly: ignore # bad-argument-type
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
elif isinstance(a, Number) and isinstance(b, Tensor):
# pyrefly: ignore # bad-argument-type
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
if a.device == torch.device("cpu"):
@ -1856,8 +1869,10 @@ def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberT
# Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
if isinstance(b, TensorLike) and isinstance(a, Number):
# pyrefly: ignore # bad-argument-type
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
elif isinstance(a, TensorLike) and isinstance(b, Number):
# pyrefly: ignore # bad-argument-type
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
# mypy: expected "Tensor"
@ -2333,6 +2348,7 @@ def all(
dim: Optional[DimsType] = None,
keepdim: bool = False,
) -> TensorLikeType:
# pyrefly: ignore # no-matching-overload
result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim))
if a.dtype == torch.uint8:
@ -2850,6 +2866,7 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
# SymInts
example = None
# pyrefly: ignore # bad-assignment
for i, t in enumerate(tensors):
if example is None:
if t.ndim != 1:
@ -3228,6 +3245,7 @@ def _normalize(
mean (Tensor): mean of the tensor along norm_dims.
rstd (Tensor): 1/std of the tensor along norm_dims.
"""
# pyrefly: ignore # no-matching-overload
norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
computation_dtype = utils.get_computation_dtype(a.dtype)
a_acc = _maybe_convert_to_dtype(a, computation_dtype)
@ -3341,6 +3359,7 @@ def native_layer_norm(
# while torch.Size([1, 2, 3]) == (1, 2, 3) is True
# therefore we use tuple(normalized_shape)
torch._check(
# pyrefly: ignore # bad-argument-type
weight is None or sym_eq(weight.shape, tuple(normalized_shape)),
lambda: "Expected weight to be of same shape as normalized_shape, but got "
+ "weight of shape "
@ -3349,6 +3368,7 @@ def native_layer_norm(
+ str(normalized_shape),
)
torch._check(
# pyrefly: ignore # bad-argument-type
bias is None or sym_eq(bias.shape, tuple(normalized_shape)),
lambda: "Expected bias to be of same shape as normalized_shape, but got "
+ "bias of shape "
@ -3359,7 +3379,10 @@ def native_layer_norm(
torch._check(
input.ndim >= normalized_ndim
and sym_eq(
input.shape[(input.ndim - normalized_ndim) :], tuple(normalized_shape)
# pyrefly: ignore # bad-argument-type
input.shape[(input.ndim - normalized_ndim) :],
# pyrefly: ignore # bad-argument-type
tuple(normalized_shape),
),
lambda: "Given normalized_shape="
+ str(normalized_shape)
@ -3953,6 +3976,7 @@ def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
@out_wrapper()
def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType:
"""Reference implementation of :func:`torch.roll`."""
# pyrefly: ignore # no-matching-overload
dims = utils.canonicalize_dims(a.ndim, dims)
# ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
if not isinstance(shifts, Iterable):
@ -3965,12 +3989,16 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike
# Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
return a.clone()
# pyrefly: ignore # bad-argument-type
if a.dim() == 0 and len(dims) > 0:
raise IndexError(
# pyrefly: ignore # index-error
f"Dimension specified as {dims[0]} but tensor has no dimensions"
)
# pyrefly: ignore # bad-argument-type
len_shifts = len(shifts)
# pyrefly: ignore # bad-argument-type
len_dims = len(dims)
if len_shifts != 1 or len_dims != 1:
if len_shifts == 0:
@ -3978,21 +4006,27 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike
# Takes care of the case when dims is not specified (default)
# By default, the tensor is flattened before shifting, after which the original shape is restored
if len_dims == 0 and len_shifts == 1:
# pyrefly: ignore # bad-argument-type
return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
if len_shifts != len_dims:
raise RuntimeError(
f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
)
assert len_dims > 1
# pyrefly: ignore # index-error
tail_shifts = shifts[1:]
# pyrefly: ignore # index-error
tail_dims = dims[1:]
# pyrefly: ignore # index-error
first_dim_rolled = torch.roll(a, (shifts[0],), dims[0])
return torch.roll(first_dim_rolled, tail_shifts, tail_dims)
# This path is taken when only one dimension is rolled
# For example to get `first_dim_rolled` above
# pyrefly: ignore # index-error
dim = dims[0]
size = a.shape[dim]
# pyrefly: ignore # index-error
start = (size - shifts[0]) % size
idx = torch.arange(size, device=a.device)
return a.index_select(dim, torch.fmod(start + idx, size))
@ -4074,7 +4108,9 @@ def softmax(
a_max = amax(a_, dim, keepdim=True)
a_exp = exp(a_ - a_max)
return _maybe_convert_to_dtype(
true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
# pyrefly: ignore # no-matching-overload
true_divide(a_exp, sum(a_exp, dim, keepdim=True)),
result_dtype,
) # type: ignore[return-value]
@ -4251,6 +4287,7 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType
return prims.squeeze(a, dims) if dims else prims.view_of(a)
ndim = a.ndim
# pyrefly: ignore # no-matching-overload
dim = utils.canonicalize_dims(ndim, dim)
dims = (dim,) if isinstance(dim, Dim) else dim
# Short-circuits if the tensor has no dimensions
@ -4391,6 +4428,7 @@ def hsplit(
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
torch._check(
# pyrefly: ignore # unsupported-operation
(split_size != 0 and a.shape[dim] % split_size == 0),
lambda: (
"torch.hsplit attempted to split along dimension "
@ -4402,6 +4440,7 @@ def hsplit(
+ "!"
),
)
# pyrefly: ignore # bad-argument-type
return tensor_split(a, split_size, dim)
torch._check_type(
@ -4432,6 +4471,7 @@ def vsplit(
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
torch._check(
# pyrefly: ignore # unsupported-operation
(split_size != 0 and a.shape[0] % split_size == 0),
lambda: (
f"torch.vsplit attempted to split along dimension 0"
@ -4442,6 +4482,7 @@ def vsplit(
f"!"
),
)
# pyrefly: ignore # bad-argument-type
return tensor_split(a, split_size, 0)
torch._check_type(
@ -4646,6 +4687,7 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
raise RuntimeError(
f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
)
# pyrefly: ignore # unsupported-operation
if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
raise RuntimeError(
"torch.dsplit attempted to split along dimension 2, "
@ -5419,6 +5461,7 @@ def logspace(
@overload
# pyrefly: ignore # inconsistent-overload
def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
pass
@ -5845,6 +5888,7 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi
# Since `where` allows type-promotion,
# cast value to correct type before passing to `where`
# pyrefly: ignore # no-matching-overload
value = _maybe_convert_to_dtype(value, a.dtype)
r = torch.where(mask, value, a) # type: ignore[arg-type]
@ -6639,6 +6683,7 @@ def _infer_scalar_type(obj):
# double.
if length == 0:
return torch.get_default_dtype()
# pyrefly: ignore # bad-assignment
for i in range(length):
cur_item = obj[i]
# TODO: test this
@ -6676,6 +6721,7 @@ def _recursive_build(
# torch.Size([1, 2])
return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
elif isinstance(obj, Number):
# pyrefly: ignore # bad-argument-type
return torch.scalar_tensor(obj, dtype=scalarType)
# seq can be a list of tensors

View File

@ -106,11 +106,13 @@ def _resize_fft_input(
if x_sizes[dims[i]] < sizes[i]:
must_copy = True
pad_idx = len(pad_amount) - 2 * dims[i] - 1
# pyrefly: ignore # unsupported-operation
pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]
if x_sizes[dims[i]] > sizes[i]:
x = x.narrow(dims[i], 0, sizes[i])
# pyrefly: ignore # bad-argument-type
return torch.constant_pad_nd(x, pad_amount) if must_copy else x

View File

@ -216,6 +216,7 @@ def matrix_norm(
# shape
check_is_matrix(A, "linalg.matrix_norm")
# dim
# pyrefly: ignore # no-matching-overload
dim = utils.canonicalize_dims(A.ndim, dim)
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
@ -223,7 +224,9 @@ def matrix_norm(
len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
)
torch._check(
# pyrefly: ignore # index-error
dim[0] != dim[1],
# pyrefly: ignore # index-error
lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
)
# dtype arg
@ -245,6 +248,7 @@ def matrix_norm(
else: # ord == "nuc"
if dtype is not None:
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
# pyrefly: ignore # index-error
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
if keepdim:
@ -268,6 +272,7 @@ def matrix_norm(
if abs_ord == 2.0:
if dtype is not None:
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
# pyrefly: ignore # index-error
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
if keepdim:
@ -275,6 +280,7 @@ def matrix_norm(
result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
return result
else: # 1, -1, inf, -inf
# pyrefly: ignore # bad-unpacking
dim0, dim1 = dim
if abs_ord == float("inf"):
dim0, dim1 = dim1, dim0

View File

@ -142,9 +142,11 @@ def _inplace_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
# nb. We use the name of the first argument used in the unary references
@wraps(fn)
def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
# pyrefly: ignore # unsupported-operation
a = args[0]
if "inplace" not in kwargs:
kwargs["inplace"] = False
# pyrefly: ignore # unsupported-operation
if kwargs["inplace"]:
torch._check(
"out" not in kwargs,
@ -625,6 +627,7 @@ def smooth_l1_loss(
)
else:
loss = torch.abs(input - target)
# pyrefly: ignore # unsupported-operation
loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
return _apply_loss_reduction(loss, reduction)

View File

@ -155,8 +155,10 @@ def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, Numbe
# Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
if isinstance(a, TensorLike) and isinstance(b, Number):
# pyrefly: ignore # bad-argument-type
b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device)
elif isinstance(b, TensorLike) and isinstance(a, Number):
# pyrefly: ignore # bad-argument-type
a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device)
# mypy: expected "Tensor"

View File

@ -130,6 +130,7 @@ def var_decomposition(
else:
raise RuntimeError("correction must be int or float")
# pyrefly: ignore # no-matching-overload
return sum / max(0, denom)

View File

@ -14,6 +14,8 @@ import torch
_IS_MONKEYTYPE_INSTALLED = True
try:
import monkeytype # type: ignore[import]
# pyrefly: ignore # import-error
from monkeytype import trace as monkeytype_trace
from monkeytype.config import _startswith, LIB_PATHS # type: ignore[import]
from monkeytype.db.base import ( # type: ignore[import]
@ -87,6 +89,7 @@ if _IS_MONKEYTYPE_INSTALLED:
super().__init__(store)
def log(self, trace: CallTrace) -> None:
# pyrefly: ignore # missing-attribute
self.traces.append(trace)
class JitTypeTraceStore(CallTraceStore):
@ -148,6 +151,7 @@ if _IS_MONKEYTYPE_INSTALLED:
def trace_logger(self) -> JitTypeTraceStoreLogger:
"""Return a JitCallTraceStoreLogger that logs to the configured trace store."""
# pyrefly: ignore # bad-argument-count
return JitTypeTraceStoreLogger(self.trace_store())
def trace_store(self) -> CallTraceStore:

View File

@ -747,6 +747,7 @@ def get_overload_annotations(mod, jit_ignored_properties):
if method_overloads is None:
continue
# pyrefly: ignore # missing-attribute
if item.__func__ in method_overloads:
raise RuntimeError(
_jit_internal.get_overload_no_implementation_error_message(

View File

@ -545,6 +545,7 @@ if _enabled:
#
# This ensures that if we use the attr again in `__init__`, it
# will look like the actual value, not an instance of Attribute.
# pyrefly: ignore # invalid-argument
if isinstance(value, Attribute):
# NB: Ensure that we set __annotations__ on the specific
# class in question, and not on a superclass (which would
@ -656,6 +657,7 @@ if _enabled:
# Finalize the ScriptModule: replace the nn.Module state with our
# custom implementations and flip the _initializing bit.
# pyrefly: ignore # missing-attribute
RecursiveScriptModule._finalize_scriptmodule(script_module)
return script_module
@ -929,6 +931,7 @@ if _enabled:
# Don't do anything here, we'll initialize the ScriptModule below
return
# pyrefly: ignore # missing-attribute
return RecursiveScriptModule._construct(
self._c._replicate_for_data_parallel(), init_fn
)
@ -938,6 +941,7 @@ if _enabled:
# This is because `super().foo()` does not use
# `__getattr__` to look up `foo`. So we need to make each method available on
# the ScriptModule manually.
# pyrefly: ignore # missing-attribute
for name, item in RecursiveScriptModule.__dict__.items():
if not callable(item) and not isinstance(item, property):
continue
@ -1006,6 +1010,7 @@ if _enabled:
if name.startswith("__") or name.endswith("_call_impl"):
continue
if (
# pyrefly: ignore # missing-attribute
name not in RecursiveScriptModule.__dict__
and name not in _compiled_methods_allowlist
):
@ -1038,6 +1043,7 @@ def call_prepare_scriptable_func_impl(obj, memo):
return memo[id(obj)]
obj = (
# pyrefly: ignore # not-callable
obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj
) # type: ignore[operator]
# Record obj in memo to avoid infinite recursion in the case of cycles in the module
@ -1135,6 +1141,7 @@ def _script_impl(
# the provide example inputs. This logs all the traces in type_trace_db
type_trace_db = JitTypeTraceStore()
if monkeytype_trace:
# pyrefly: ignore # bad-argument-count
monkeytype_config = JitTypeTraceConfig(type_trace_db)
with monkeytype_trace(monkeytype_config):
if isinstance(example_inputs, dict):

View File

@ -166,11 +166,25 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
cu = torch._C.CompilationUnit()
if isinstance(f, (str, os.PathLike)):
cpp_module = torch._C.import_ir_module(
cu, os.fspath(f), map_location, _extra_files, _restore_shapes
# pyrefly: ignore # no-matching-overload, bad-argument-count
cu,
# pyrefly: ignore # no-matching-overload
os.fspath(f),
map_location,
_extra_files,
# pyrefly: ignore # bad-argument-count
_restore_shapes,
) # type: ignore[call-arg]
else:
cpp_module = torch._C.import_ir_module_from_buffer(
cu, f.read(), map_location, _extra_files, _restore_shapes
# pyrefly: ignore # missing-attribute, bad-argument-count
cu,
# pyrefly: ignore # missing-attribute
f.read(),
map_location,
_extra_files,
# pyrefly: ignore # bad-argument-count
_restore_shapes,
) # type: ignore[call-arg]
# TODO: Pretty sure this approach loses ConstSequential status and such
@ -196,6 +210,7 @@ def validate_map_location(map_location=None):
def jit_module_from_flatbuffer(f):
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
else:
@ -245,6 +260,7 @@ def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
extra_files = {}
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
torch._C._save_jit_module(m._c, f, extra_files)
else:

View File

@ -561,6 +561,7 @@ def cat(tensors: list[list[int]], dim: int):
for i in range(len(tensors)):
tensor = tensors[i]
if not should_skip(tensor):
# pyrefly: ignore # bad-argument-type
check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
cat_dim_size = cat_dim_size + tensor[dim]

View File

@ -169,6 +169,7 @@ def _clone_inputs(args):
else:
return a.clone(memory_format=torch.preserve_format)
# pyrefly: ignore # missing-attribute
return function._nested_map(
lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
)(args)
@ -335,6 +336,7 @@ def _check_trace(
if is_trace_module:
copied_dict = {}
# pyrefly: ignore # missing-attribute
for name, data in inputs.items():
copied_dict[name] = _clone_inputs(data)
check_mod = torch.jit.trace_module(
@ -739,6 +741,7 @@ def _trace_impl(
example_inputs = (example_inputs,)
# done primarily so that weird iterables fail here and not pybind11 code
elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple):
# pyrefly: ignore # bad-argument-type
example_inputs = tuple(example_inputs)
var_lookup_fn = _create_interpreter_name_lookup_fn(0)
@ -765,6 +768,7 @@ def _trace_impl(
traced = torch._C._create_function_from_trace(
name,
func,
# pyrefly: ignore # bad-argument-type
example_inputs,
var_lookup_fn,
strict,

View File

@ -98,6 +98,7 @@ class EvalEnv:
def __init__(self, rcb):
self.rcb = rcb
if torch.distributed.rpc.is_available():
# pyrefly: ignore # unsupported-operation
self.env["RRef"] = RRef
def __getitem__(self, name):

View File

@ -115,6 +115,7 @@ node_start_tokens = {
ast.Continue: "continue",
}
# pyrefly: ignore # no-matching-overload
pretty_node_names.update(
{
ast.AsyncFunctionDef: "async function definitions",
@ -125,6 +126,7 @@ pretty_node_names.update(
}
)
# pyrefly: ignore # no-matching-overload
node_start_tokens.update(
{
ast.AsyncFunctionDef: "async def",
@ -135,6 +137,7 @@ node_start_tokens.update(
}
)
# pyrefly: ignore # no-matching-overload
pretty_node_names.update(
{
ast.AnnAssign: "annotated assignments",
@ -859,6 +862,7 @@ class ExprBuilder(Builder):
ast.RShift: ">>",
}
# pyrefly: ignore # unsupported-operation
binop_map[ast.MatMult] = "@"
unop_map = {
@ -1220,6 +1224,7 @@ class ExprBuilder(Builder):
s += "{}"
args.append(build_expr(ctx, value.value))
elif isinstance(value, ast.Constant):
# pyrefly: ignore # unsupported-operation
s += value.value
else:
raise NotSupportedError(r, "Unsupported value in JoinedStr")

View File

@ -44,10 +44,13 @@ def _load_for_lite_interpreter(f, map_location=None):
map_location = validate_map_location(map_location)
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
else:
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
f.read(), map_location
# pyrefly: ignore # missing-attribute
f.read(),
map_location,
)
return LiteScriptModule(cpp_module)
@ -103,8 +106,10 @@ def _get_model_bytecode_version(f_input) -> int:
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
return torch._C._get_model_bytecode_version(os.fspath(f_input))
else:
# pyrefly: ignore # missing-attribute
return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
@ -135,8 +140,10 @@ def _get_mobile_model_contained_types(f_input) -> int:
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
else:
# pyrefly: ignore # missing-attribute
return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
@ -161,11 +168,18 @@ def _backport_for_mobile(f_input, f_output, to_version):
isinstance(f_output, (str, os.PathLike))
):
return torch._C._backport_for_mobile(
os.fspath(f_input), os.fspath(f_output), to_version
# pyrefly: ignore # no-matching-overload
os.fspath(f_input),
# pyrefly: ignore # no-matching-overload
os.fspath(f_output),
to_version,
)
else:
return torch._C._backport_for_mobile_from_buffer(
f_input.read(), str(f_output), to_version
# pyrefly: ignore # missing-attribute
f_input.read(),
str(f_output),
to_version,
)
@ -184,10 +198,13 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
else:
return torch._C._backport_for_mobile_from_buffer_to_buffer(
f_input.read(), to_version
# pyrefly: ignore # missing-attribute
f_input.read(),
to_version,
)
@ -227,6 +244,8 @@ def _get_model_ops_and_info(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
return torch._C._get_model_ops_and_info(os.fspath(f_input))
else:
# pyrefly: ignore # missing-attribute
return torch._C._get_model_ops_and_info(f_input.read())

View File

@ -261,6 +261,7 @@ def _get_global_builtins():
magic_methods_rows = []
for fn, magic_method in magic_methods:
# pyrefly: ignore # bad-argument-type
magic_methods_rows.append(f'"{fn}", "``{magic_method}``"')
schematized_ops = []
@ -279,6 +280,7 @@ def _get_global_builtins():
table_row = (
f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"'
)
# pyrefly: ignore # bad-argument-type
schemaless_ops.append(table_row)
schematized_ops_str = "\n".join(schematized_ops)

View File

@ -78,6 +78,7 @@ def _adjust_lr(
A, B = param_shape[:2]
if adjust_lr_fn is None or adjust_lr_fn == "original":
# pyrefly: ignore # no-matching-overload
adjusted_ratio = math.sqrt(max(1, A / B))
elif adjust_lr_fn == "match_rms_adamw":
adjusted_ratio = 0.2 * math.sqrt(max(A, B))

View File

@ -415,6 +415,7 @@ def _single_tensor_adam(
if weight_decay.requires_grad:
grad = grad.addcmul_(param.clone(), weight_decay)
else:
# pyrefly: ignore # bad-argument-type
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
@ -444,6 +445,7 @@ def _single_tensor_adam(
device_beta1 = beta1
# Decay the first and second moment running average coefficient
# pyrefly: ignore # no-matching-overload
exp_avg.lerp_(grad, 1 - device_beta1)
# Nested if is necessary to bypass jitscript rules
@ -692,6 +694,7 @@ def _multi_tensor_adam(
device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
)
# pyrefly: ignore # no-matching-overload
torch._foreach_mul_(device_exp_avg_sqs, beta2)
# Due to the strictness of the _foreach_addcmul API, we can't have a single

View File

@ -263,6 +263,7 @@ def _single_tensor_asgd(
ax.copy_(param)
if capturable:
# pyrefly: ignore # unsupported-operation
eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
else:

View File

@ -113,9 +113,11 @@ def _strong_wolfe(
# compute new trial value
t = _cubic_interpolate(
# pyrefly: ignore # index-error
bracket[0],
bracket_f[0],
bracket_gtd[0], # type: ignore[possibly-undefined]
# pyrefly: ignore # index-error
bracket[1],
bracket_f[1],
bracket_gtd[1],
@ -151,6 +153,7 @@ def _strong_wolfe(
if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
# Armijo condition not satisfied or not lower than lowest point
# pyrefly: ignore # unsupported-operation
bracket[high_pos] = t
bracket_f[high_pos] = f_new
bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
@ -160,14 +163,17 @@ def _strong_wolfe(
if abs(gtd_new) <= -c2 * gtd:
# Wolfe conditions satisfied
done = True
# pyrefly: ignore # index-error
elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
# old high becomes new low
# pyrefly: ignore # unsupported-operation
bracket[high_pos] = bracket[low_pos]
bracket_f[high_pos] = bracket_f[low_pos]
bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined]
bracket_gtd[high_pos] = bracket_gtd[low_pos]
# new point becomes new low
# pyrefly: ignore # unsupported-operation
bracket[low_pos] = t
bracket_f[low_pos] = f_new
bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
@ -252,6 +258,7 @@ class LBFGS(Optimizer):
def _numel(self):
if self._numel_cache is None:
# pyrefly: ignore # bad-assignment
self._numel_cache = sum(
2 * p.numel() if torch.is_complex(p) else p.numel()
for p in self._params

View File

@ -1665,6 +1665,7 @@ class ReduceLROnPlateau(LRScheduler):
self.default_min_lr = None
self.min_lrs = list(min_lr)
else:
# pyrefly: ignore # bad-assignment
self.default_min_lr = min_lr
self.min_lrs = [min_lr] * len(optimizer.param_groups)
@ -1724,6 +1725,7 @@ class ReduceLROnPlateau(LRScheduler):
"of the `optimizer` param groups."
)
else:
# pyrefly: ignore # bad-assignment
self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)
for i, param_group in enumerate(self.optimizer.param_groups):
@ -1903,10 +1905,13 @@ class CyclicLR(LRScheduler):
self.max_lrs = _format_param("max_lr", optimizer, max_lr)
# pyrefly: ignore # bad-assignment
step_size_up = float(step_size_up)
step_size_down = (
# pyrefly: ignore # bad-assignment
float(step_size_down) if step_size_down is not None else step_size_up
)
# pyrefly: ignore # unsupported-operation
self.total_size = step_size_up + step_size_down
self.step_ratio = step_size_up / self.total_size

View File

@ -62,6 +62,7 @@ def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
import torch._dynamo
# pyrefly: ignore # unsupported-operation
self = cast(Optimizer, args[0]) # assume first positional arg is `self`
prev_grad = torch.is_grad_enabled()
try:
@ -135,11 +136,13 @@ def _disable_dynamo_if_unsupported(
if torch.compiler.is_compiling() and (
not kwargs.get("capturable", False)
and has_state_steps
# pyrefly: ignore # unsupported-operation
and (arg := args[state_steps_ind])
and isinstance(arg, Sequence)
and arg[0].is_cuda
or (
"state_steps" in kwargs
# pyrefly: ignore # unsupported-operation
and (kwarg := kwargs["state_steps"])
and isinstance(kwarg, Sequence)
and kwarg[0].is_cuda
@ -359,14 +362,18 @@ class Optimizer:
_optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: dict[int, OptimizerPostHook]
# pyrefly: ignore # not-a-type
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: (
# pyrefly: ignore # not-a-type
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_pre_hooks: (
# pyrefly: ignore # not-a-type
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_post_hooks: (
# pyrefly: ignore # not-a-type
'OrderedDict[int, Callable[["Optimizer"], None]]'
)
@ -391,6 +398,7 @@ class Optimizer:
self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
self.param_groups: list[dict[str, Any]] = []
# pyrefly: ignore # no-matching-overload
param_groups = list(params)
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
@ -514,6 +522,7 @@ class Optimizer:
f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
)
# pyrefly: ignore # invalid-param-spec
out = func(*args, **kwargs)
self._optimizer_step_code()
@ -949,7 +958,14 @@ class Optimizer:
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
return Optimizer._process_value_according_to_param_policy(
param, value, param_id, param_groups, key
# pyrefly: ignore # bad-argument-type
param,
value,
# pyrefly: ignore # bad-argument-type
param_id,
# pyrefly: ignore # bad-argument-type
param_groups,
key,
)
elif isinstance(value, dict):
return {
@ -960,6 +976,7 @@ class Optimizer:
}
elif isinstance(value, Iterable):
return type(value)(
# pyrefly: ignore # bad-argument-count
_cast(param, v, param_id=param_id, param_groups=param_groups)
for v in value
) # type: ignore[call-arg]

View File

@ -322,6 +322,7 @@ def _single_tensor_radam(
rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2
def _compute_rect():
# pyrefly: ignore # unsupported-operation
return (
(rho_t - 4)
* (rho_t - 2)
@ -336,6 +337,7 @@ def _single_tensor_radam(
else:
exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)
# pyrefly: ignore # unsupported-operation
return (bias_correction2**0.5) / exp_avg_sq_sqrt
# Compute the variance rectification term and update parameters accordingly

View File

@ -337,6 +337,7 @@ def _single_tensor_sgd(
if not torch.jit.is_scripting():
lr = _to_scalar(lr)
# pyrefly: ignore # bad-assignment
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
@ -347,6 +348,7 @@ def _single_tensor_sgd(
# usually this is the differentiable path, which is why the param.clone() is needed
grad = grad.addcmul_(param.clone(), weight_decay)
else:
# pyrefly: ignore # bad-argument-type
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
@ -370,6 +372,7 @@ def _single_tensor_sgd(
if lr.requires_grad:
param.addcmul_(grad, lr, value=-1)
else:
# pyrefly: ignore # bad-argument-type
param.add_(grad, alpha=-lr)
else:
param.add_(grad, alpha=-lr)
@ -430,10 +433,12 @@ def _multi_tensor_sgd(
all_states_with_momentum_buffer = True
for i in range(len(device_momentum_buffer_list)):
# pyrefly: ignore # index-error
if device_momentum_buffer_list[i] is None:
all_states_with_momentum_buffer = False
break
else:
# pyrefly: ignore # index-error
bufs.append(cast(Tensor, device_momentum_buffer_list[i]))
if all_states_with_momentum_buffer:
@ -441,12 +446,15 @@ def _multi_tensor_sgd(
torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
else:
bufs = []
# pyrefly: ignore # bad-assignment
for i in range(len(device_momentum_buffer_list)):
# pyrefly: ignore # index-error
if device_momentum_buffer_list[i] is None:
buf = device_momentum_buffer_list[i] = momentum_buffer_list[
indices[i]
] = device_grads[i].detach().clone()
else:
# pyrefly: ignore # index-error
buf = cast(Tensor, device_momentum_buffer_list[i])
buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

View File

@ -249,11 +249,13 @@ class AveragedModel(Module):
def update_parameters(self, model: Module):
"""Update model parameters."""
self_param = (
# pyrefly: ignore # bad-argument-type
itertools.chain(self.module.parameters(), self.module.buffers())
if self.use_buffers
else self.parameters()
)
model_param = (
# pyrefly: ignore # bad-argument-type
itertools.chain(model.parameters(), model.buffers())
if self.use_buffers
else model.parameters()
@ -300,8 +302,11 @@ class AveragedModel(Module):
for p_averaged, p_model in zip( # type: ignore[assignment]
self_param_detached, model_param_detached
):
# pyrefly: ignore # missing-attribute
n_averaged = self.n_averaged.to(p_averaged.device)
# pyrefly: ignore # missing-attribute
p_averaged.detach().copy_(
# pyrefly: ignore # missing-attribute, bad-argument-type
self.avg_fn(p_averaged.detach(), p_model, n_averaged)
)
@ -489,12 +494,14 @@ class SWALR(LRScheduler):
step = self._step_count - 1
if self.anneal_epochs == 0:
step = max(1, step)
# pyrefly: ignore # no-matching-overload
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
prev_alpha = self.anneal_func(prev_t)
prev_lrs = [
self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
for group in self.optimizer.param_groups
]
# pyrefly: ignore # no-matching-overload
t = max(0, min(1, step / max(1, self.anneal_epochs)))
alpha = self.anneal_func(t)
return [