mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Pyrefly suppressions 4/n (#164615)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: uncomment lines in the pyrefly.toml file step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1 after: 0 errors (2,753 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615 Approved by: https://github.com/oulgen
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							4bd1505f84
						
					
				
				
					commit
					4ab847bbc7
				
			| @ -29,15 +29,11 @@ project-excludes = [ | ||||
|   "torch/fx/**", | ||||
|   "torch/distributions/**", | ||||
|   "torch/onnx/**", | ||||
|   "torch/_refs/**", | ||||
|   "torch/_export/**", | ||||
|   "torch/jit/**", | ||||
|   "torch/optim/**", | ||||
|   "torch/_higher_order_ops/**", | ||||
|   # formatting issues | ||||
|   "torch/linalg/__init__.py", | ||||
|   "torch/package/importer.py", | ||||
|   "torch/package/_package_pickler.py", | ||||
|   "torch/jit/annotations.py", | ||||
|   # ==== | ||||
|   "benchmarks/instruction_counts/main.py", | ||||
|   "benchmarks/instruction_counts/definitions/setup.py", | ||||
|  | ||||
| @ -1097,6 +1097,7 @@ class TS2FXGraphConverter: | ||||
|  | ||||
|             # Update the value of loop local variables. | ||||
|             if node.outputsSize() >= 1: | ||||
|                 # pyrefly: ignore  # bad-assignment | ||||
|                 for i, outp in enumerate(node.outputs()): | ||||
|                     output_name = outp.debugName() | ||||
|                     self.name_to_node[output_name] = self.fx_graph.call_function( | ||||
| @ -1109,6 +1110,7 @@ class TS2FXGraphConverter: | ||||
|                     fx_block_args[i] = self.name_to_node[output_name] | ||||
|  | ||||
|             # Update the value of global variables, whose values are modified inplace. | ||||
|             # pyrefly: ignore  # bad-assignment | ||||
|             for i, name in enumerate( | ||||
|                 subgraph_converter.name_update_from_subblock_to_parent | ||||
|             ): | ||||
|  | ||||
| @ -132,6 +132,7 @@ def _make_export_case(m, name, configs): | ||||
|             m.__doc__ is not None | ||||
|         ), f"Could not find description or docstring for export case: {m}" | ||||
|         configs = {**configs, "description": m.__doc__} | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     return ExportCase(**{**configs, "model": m, "name": name}) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -3,10 +3,12 @@ import torch | ||||
|  | ||||
| class MyAutogradFunction(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward(ctx, x): | ||||
|         return x.clone() | ||||
|  | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def backward(ctx, grad_output): | ||||
|         return grad_output + 1 | ||||
|  | ||||
|  | ||||
| @ -39,6 +39,7 @@ def get_class_if_classified_error(e: Exception) -> Optional[str]: | ||||
|         TorchRuntimeError: None, | ||||
|     } | ||||
|     if type(e) in _ALLOW_LIST: | ||||
|         # pyrefly: ignore  # index-error | ||||
|         attr_name = _ALLOW_LIST[type(e)] | ||||
|         if attr_name is None: | ||||
|             return ALWAYS_CLASSIFIED | ||||
|  | ||||
| @ -101,6 +101,7 @@ class _KeyPathTrie: | ||||
|             assert len(kp) > 0 | ||||
|             k, *kp = kp  # type: ignore[assignment] | ||||
|             node = node[k] | ||||
|         # pyrefly: ignore  # bad-return | ||||
|         return node, kp | ||||
|  | ||||
|  | ||||
| @ -139,6 +140,7 @@ def key_path_to_source( | ||||
|         source: Source = LocalSource("args") | ||||
|     else: | ||||
|         source, kp = sourced_prefixes.get(kp) | ||||
|     # pyrefly: ignore  # bad-assignment | ||||
|     for k in kp: | ||||
|         if isinstance(k, SequenceKey): | ||||
|             source = GetItemSource(source, k.idx) | ||||
| @ -354,10 +356,12 @@ def _override_builtin_ops(): | ||||
|     original_min = builtins.min | ||||
|     original_pow = math.pow | ||||
|  | ||||
|     # pyrefly: ignore  # bad-assignment | ||||
|     builtins.max = functools.partial( | ||||
|         _tensor_min_max, real_callable=original_max, tensor_callable=torch.maximum | ||||
|     ) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-assignment | ||||
|     builtins.min = functools.partial( | ||||
|         _tensor_min_max, real_callable=original_min, tensor_callable=torch.minimum | ||||
|     ) | ||||
| @ -1083,6 +1087,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): | ||||
|  | ||||
|                 def run(): | ||||
|                     # Run sequence. | ||||
|                     # pyrefly: ignore  # index-error | ||||
|                     t = args[0] | ||||
|                     for _method, _args in sequence: | ||||
|                         t = _method(t, *_args) | ||||
|  | ||||
| @ -188,6 +188,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): | ||||
|             self.callback = callback | ||||
|             self.node: torch.fx.Node = next(iter(gm.graph.nodes)) | ||||
|  | ||||
|         # pyrefly: ignore  # bad-override | ||||
|         def placeholder( | ||||
|             self, | ||||
|             target: str,  # type: ignore[override] | ||||
| @ -439,6 +440,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): | ||||
|         ) | ||||
|         self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode | ||||
|         interpreter = self.ExportInterpreter(self, graph_module) | ||||
|         # pyrefly: ignore  # bad-assignment | ||||
|         prev_interpreter, self.interpreter = ( | ||||
|             self.interpreter, | ||||
|             torch.fx.Interpreter(  # type: ignore[assignment] | ||||
|  | ||||
| @ -32,6 +32,7 @@ def _node_metadata_hook( | ||||
|     that nodes being added are only call_function nodes, and copies over the | ||||
|     first argument node's nn_module_stack. | ||||
|     """ | ||||
|     # pyrefly: ignore  # bad-assignment | ||||
|     fake_mode = fake_mode or contextlib.nullcontext() | ||||
|  | ||||
|     assert node.op == "call_function" and callable(node.target), ( | ||||
| @ -47,6 +48,7 @@ def _node_metadata_hook( | ||||
|         fake_args, fake_kwargs = pytree.tree_map_only( | ||||
|             torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs) | ||||
|         ) | ||||
|         # pyrefly: ignore  # bad-context-manager | ||||
|         with fake_mode, enable_python_dispatcher(): | ||||
|             fake_res = node.target(*fake_args, **fake_kwargs) | ||||
|         node.meta["val"] = fake_res | ||||
| @ -81,7 +83,9 @@ def _node_metadata_hook( | ||||
|     node.meta["torch_fn"] = node.meta.get( | ||||
|         "torch_fn", | ||||
|         ( | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             f"{node.target.__name__}_0", | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             f"{node.target.__class__.__name__}.{node.target.__name__}", | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
| @ -567,6 +567,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule): | ||||
|     quantized = False | ||||
|  | ||||
|     last_quantized_node = None | ||||
|     # pyrefly: ignore  # bad-assignment | ||||
|     for node in gm.graph.nodes: | ||||
|         if isinstance(node.target, OpOverload): | ||||
|             with gm.graph.inserting_before(node): | ||||
| @ -629,6 +630,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule): | ||||
|                     attr_names_to_clean.add(k) | ||||
|                 if k == "_buffers": | ||||
|                     buffer_name_to_clean = set() | ||||
|                     # pyrefly: ignore  # missing-attribute | ||||
|                     for b_name, b_value in v.items(): | ||||
|                         if isinstance(b_value, torch.Tensor) and b_value.dtype in [ | ||||
|                             torch.qint8, | ||||
| @ -636,6 +638,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule): | ||||
|                         ]: | ||||
|                             buffer_name_to_clean.add(b_name) | ||||
|                     for b_name in buffer_name_to_clean: | ||||
|                         # pyrefly: ignore  # missing-attribute | ||||
|                         v.pop(b_name, None) | ||||
|             for attr_name in attr_names_to_clean: | ||||
|                 delattr(submod, attr_name) | ||||
|  | ||||
| @ -35,6 +35,7 @@ def _replace_with_hop_helper( | ||||
|         ) | ||||
|         call_func_node.meta["torch_fn"] = ( | ||||
|             f"{wrap_hoo.__name__}", | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}", | ||||
|         ) | ||||
|         if isinstance(output_args, (tuple, list)): | ||||
|  | ||||
| @ -54,6 +54,7 @@ def _postprocess_serialized_shapes( | ||||
|         ) | ||||
|         for k, v in sorted(dims.items()) | ||||
|     } | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims) | ||||
|     if to_dict: | ||||
|         return _dataclass_to_dict(spec) | ||||
| @ -183,6 +184,7 @@ def _dump_dynamic_shapes( | ||||
|     kwargs = kwargs or {} | ||||
|     if isinstance(dynamic_shapes, dict): | ||||
|         dynamic_shapes = dynamic_shapes.values()  # type: ignore[assignment] | ||||
|     # pyrefly: ignore  # bad-assignment, bad-argument-type | ||||
|     dynamic_shapes = tuple(dynamic_shapes) | ||||
|     combined_args = tuple(args) + tuple(kwargs.values()) | ||||
|  | ||||
|  | ||||
| @ -623,7 +623,9 @@ class _Commit: | ||||
| def update_schema(): | ||||
|     import importlib.resources | ||||
|  | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     if importlib.resources.is_resource(__package__, "schema.yaml"): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         content = importlib.resources.read_text(__package__, "schema.yaml") | ||||
|         match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content) | ||||
|         _check(match is not None, "checksum not found in schema.yaml") | ||||
| @ -631,7 +633,9 @@ def update_schema(): | ||||
|         checksum_head = match.group(1) | ||||
|  | ||||
|         thrift_content = importlib.resources.read_text( | ||||
|             __package__, "export_schema.thrift" | ||||
|             # pyrefly: ignore  # bad-argument-type | ||||
|             __package__, | ||||
|             "export_schema.thrift", | ||||
|         ) | ||||
|         match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content) | ||||
|         _check(match is not None, "checksum not found in export_schema.thrift") | ||||
| @ -654,7 +658,9 @@ def update_schema(): | ||||
|  | ||||
|     src, cpp_header, thrift_schema = _staged_schema() | ||||
|     additions, subtractions = _diff_schema(dst, src) | ||||
|     # pyrefly: ignore  # missing-attribute | ||||
|     yaml_path = __package__.replace(".", "/") + "/schema.yaml" | ||||
|     # pyrefly: ignore  # missing-attribute | ||||
|     thrift_schema_path = __package__.replace(".", "/") + "/export_schema.thrift" | ||||
|     torch_prefix = "torch/" | ||||
|     assert yaml_path.startswith(torch_prefix)  # sanity check | ||||
|  | ||||
| @ -383,6 +383,7 @@ def _reconstruct_fake_tensor( | ||||
|     fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) | ||||
|     if is_parameter: | ||||
|         fake_tensor = torch.nn.Parameter(fake_tensor)  # type: ignore[assignment] | ||||
|     # pyrefly: ignore  # bad-return | ||||
|     return fake_tensor | ||||
|  | ||||
|  | ||||
| @ -2740,6 +2741,7 @@ class GraphModuleDeserializer(metaclass=Final): | ||||
|                     serialized_node.metadata | ||||
|                 ) | ||||
|                 assert arg is not None | ||||
|                 # pyrefly: ignore  # bad-argument-type | ||||
|                 self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata) | ||||
|                 fx_node.meta["val"] = tuple(meta_val) | ||||
|                 self.serialized_name_to_node[fx_node.name] = fx_node | ||||
| @ -3165,6 +3167,7 @@ def _dict_to_dataclass(cls, data): | ||||
|         _value = next(iter(data.values())) | ||||
|         assert isinstance(_type, str) | ||||
|         field_type = cls.__annotations__[_type] | ||||
|         # pyrefly: ignore  # missing-attribute | ||||
|         return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) | ||||
|     elif dataclasses.is_dataclass(cls): | ||||
|         fields = {} | ||||
| @ -3471,18 +3474,23 @@ def _canonicalize_graph( | ||||
|         n.metadata.clear() | ||||
|  | ||||
|     # Stage 4: Aggregate values. | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     sorted_tensor_values = dict( | ||||
|         sorted(graph.tensor_values.items(), key=operator.itemgetter(0)) | ||||
|     ) | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     sorted_sym_int_values = dict( | ||||
|         sorted(graph.sym_int_values.items(), key=operator.itemgetter(0)) | ||||
|     ) | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     sorted_sym_float_values = dict( | ||||
|         sorted(graph.sym_float_values.items(), key=operator.itemgetter(0)) | ||||
|     ) | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     sorted_sym_bool_values = dict( | ||||
|         sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0)) | ||||
|     ) | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     sorted_custom_obj_values = dict( | ||||
|         sorted(graph.custom_obj_values.items(), key=operator.itemgetter(0)) | ||||
|     ) | ||||
| @ -3539,6 +3547,7 @@ def canonicalize( | ||||
|         ExportedProgram: The canonicalized exported program. | ||||
|     """ | ||||
|     ep = copy.deepcopy(ep) | ||||
|     # pyrefly: ignore  # annotation-mismatch | ||||
|     constants: set[str] = constants or set() | ||||
|  | ||||
|     opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0))) | ||||
|  | ||||
| @ -34,7 +34,7 @@ from torch.fx._pytree import ( | ||||
|     _deregister_pytree_flatten_spec, | ||||
|     register_pytree_flatten_spec, | ||||
| ) | ||||
| from torch.utils._pytree import ( | ||||
| from torch.utils._pytree import (  # pyrefly: ignore  # deprecated | ||||
|     _deregister_pytree_node, | ||||
|     _register_pytree_node, | ||||
|     Context, | ||||
| @ -470,7 +470,14 @@ def _check_input_constraints_for_graph( | ||||
|                 ) | ||||
|         elif isinstance(node_val, torch.SymInt): | ||||
|             _check_symint( | ||||
|                 node_val, arg, range_constraints, unification_map, key_path, None | ||||
|                 # pyrefly: ignore  # bad-argument-type | ||||
|                 node_val, | ||||
|                 # pyrefly: ignore  # bad-argument-type | ||||
|                 arg, | ||||
|                 range_constraints, | ||||
|                 unification_map, | ||||
|                 key_path, | ||||
|                 None, | ||||
|             ) | ||||
|  | ||||
|  | ||||
| @ -1115,12 +1122,14 @@ def placeholder_naming_pass( | ||||
|         if (  # handle targets for custom objects | ||||
|             spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map | ||||
|         ): | ||||
|             # pyrefly: ignore  # index-error | ||||
|             spec.target = name_map[spec.target][4:]  # strip obj_ prefix | ||||
|  | ||||
|     for spec in export_graph_signature.output_specs: | ||||
|         if spec.arg.name in name_map: | ||||
|             spec.arg.name = name_map[spec.arg.name] | ||||
|         if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map: | ||||
|             # pyrefly: ignore  # index-error | ||||
|             spec.target = name_map[spec.target] | ||||
|  | ||||
|     # rename keys in constants dict for custom objects | ||||
|  | ||||
| @ -96,6 +96,7 @@ class AssociativeScanOp(HigherOrderOperator): | ||||
|         validate_subgraph_args_types(additional_inputs) | ||||
|         return super().__call__(combine_fn, xs, additional_inputs) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def gen_schema(self, combine_fn, xs, additional_inputs): | ||||
|         from torch._higher_order_ops.schema import HopSchemaGenerator | ||||
|         from torch._higher_order_ops.utils import materialize_as_graph | ||||
| @ -648,6 +649,7 @@ class AssociativeScanAutogradOp(torch.autograd.Function): | ||||
|     """ | ||||
|  | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward( | ||||
|         ctx, | ||||
|         combine_fn, | ||||
|  | ||||
| @ -609,6 +609,7 @@ def do_auto_functionalize_v2( | ||||
|     normalized_kwargs = {} | ||||
|  | ||||
|     schema = op._schema | ||||
|     # pyrefly: ignore  # bad-assignment | ||||
|     op = op._op if isinstance(op, HopInstance) else op | ||||
|     assert isinstance(op, get_args(_MutableOpType)) | ||||
|  | ||||
|  | ||||
| @ -170,6 +170,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC): | ||||
|             out = self(functionalized_subgraph, *unwrapped_operands, **kwargs) | ||||
|         return ctx.wrap_tensors(out) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def gen_schema(self, subgraph, *operands, **kwargs): | ||||
|         from .schema import HopSchemaGenerator | ||||
|  | ||||
| @ -214,6 +215,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC): | ||||
|  | ||||
| class BaseHOPFunction(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward(ctx, hop, subgraph, kwargs, *operands): | ||||
|         ctx.hop = hop | ||||
|         ctx.operands = operands | ||||
|  | ||||
| @ -52,6 +52,7 @@ class CondOp(HigherOrderOperator): | ||||
|         validate_subgraph_args_types(operands) | ||||
|         return super().__call__(pred, true_fn, false_fn, operands) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def gen_schema(self, pred, true_fn, false_fn, operands): | ||||
|         from torch._higher_order_ops.schema import HopSchemaGenerator | ||||
|         from torch._higher_order_ops.utils import materialize_as_graph | ||||
| @ -284,6 +285,7 @@ def cond_op_dense(pred, true_fn, false_fn, operands): | ||||
|  | ||||
| class CondAutogradOp(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward( | ||||
|         ctx, | ||||
|         pred, | ||||
|  | ||||
| @ -298,4 +298,5 @@ def handle_effects( | ||||
|     assert isinstance(wrapped_token, torch.Tensor) | ||||
|     tokens[key] = wrapped_token | ||||
|  | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     return ctx.wrap_tensors(unwrapped_outs) | ||||
|  | ||||
| @ -354,12 +354,18 @@ def trace_flex_attention( | ||||
|         score_mod_other_buffers, | ||||
|         mask_mod_other_buffers, | ||||
|     ) | ||||
|     # pyrefly: ignore  # missing-attribute | ||||
|     proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) | ||||
|     out_proxy = proxy_mode.tracer.create_proxy( | ||||
|         "call_function", flex_attention, proxy_args, {} | ||||
|     ) | ||||
|     return track_tensor_tree( | ||||
|         example_out, out_proxy, constant=None, tracer=proxy_mode.tracer | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         example_out, | ||||
|         out_proxy, | ||||
|         constant=None, | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         tracer=proxy_mode.tracer, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @ -621,6 +627,7 @@ def create_fw_bw_graph( | ||||
|  | ||||
| class FlexAttentionAutogradOp(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward( | ||||
|         ctx: Any, | ||||
|         query: Tensor, | ||||
| @ -1063,6 +1070,7 @@ def trace_flex_attention_backward( | ||||
|         score_mod_other_buffers, | ||||
|         mask_mod_other_buffers, | ||||
|     ) | ||||
|     # pyrefly: ignore  # missing-attribute | ||||
|     proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) | ||||
|     out_proxy = proxy_mode.tracer.create_proxy( | ||||
|         "call_function", | ||||
| @ -1072,7 +1080,12 @@ def trace_flex_attention_backward( | ||||
|         name="flex_attention_backward", | ||||
|     ) | ||||
|     return track_tensor_tree( | ||||
|         example_out, out_proxy, constant=None, tracer=proxy_mode.tracer | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         example_out, | ||||
|         out_proxy, | ||||
|         constant=None, | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         tracer=proxy_mode.tracer, | ||||
|     ) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -86,6 +86,7 @@ class InvokeSubgraphHOP(HigherOrderOperator): | ||||
|  | ||||
|         return super().__call__(subgraph, identifier, *operands) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def gen_schema(self, subgraph, identifier, *operands): | ||||
|         from torch._higher_order_ops.schema import HopSchemaGenerator | ||||
|         from torch._higher_order_ops.utils import ( | ||||
| @ -401,6 +402,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function): | ||||
|     """ | ||||
|  | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward( | ||||
|         ctx, | ||||
|         subgraph, | ||||
| @ -477,6 +479,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function): | ||||
|         for tangent in filtered_grad_outs: | ||||
|             metadata = extract_tensor_metadata(tangent) | ||||
|             metadata._flatten_into(tangent_metadata, fake_mode, state) | ||||
|         # pyrefly: ignore  # bad-assignment | ||||
|         tangent_metadata = tuple(tangent_metadata) | ||||
|  | ||||
|         # bw_graph is a joint graph with signature (*primals_and_tangents) and | ||||
|  | ||||
| @ -197,6 +197,7 @@ def create_hop_fw_bw( | ||||
|  | ||||
| class LocalMapAutogradOp(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward( | ||||
|         ctx: Any, | ||||
|         fw_gm: GraphModule, | ||||
| @ -243,6 +244,7 @@ class LocalMapAutogradOp(torch.autograd.Function): | ||||
|             ) | ||||
|  | ||||
|             for i, meta in ctx.expected_tangent_metadata.items(): | ||||
|                 # pyrefly: ignore  # bad-argument-type | ||||
|                 grads[i] = coerce_to_expected_memory_format(grads[i], meta) | ||||
|  | ||||
|             grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads) | ||||
|  | ||||
| @ -125,6 +125,7 @@ def map( | ||||
|  | ||||
| class MapAutogradOp(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward(ctx, f, num_mapped_args, *flat_args): | ||||
|         ctx._f = f | ||||
|         ctx._num_mapped_args = num_mapped_args | ||||
|  | ||||
| @ -241,6 +241,7 @@ class ScanOp(HigherOrderOperator): | ||||
|         validate_subgraph_args_types(additional_inputs) | ||||
|         return super().__call__(combine_fn, init, xs, additional_inputs) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def gen_schema(self, combine_fn, init, xs, additional_inputs): | ||||
|         from torch._higher_order_ops.schema import HopSchemaGenerator | ||||
|         from torch._higher_order_ops.utils import materialize_as_graph | ||||
| @ -448,6 +449,7 @@ class ScanAutogradOp(torch.autograd.Function): | ||||
|     """ | ||||
|  | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward( | ||||
|         ctx, | ||||
|         hop_partitioned_graph, | ||||
|  | ||||
| @ -292,6 +292,7 @@ def generate_ttir( | ||||
|             ordered_args[name] = 2 | ||||
|         elif ( | ||||
|             stable_meta := maybe_unpack_tma_stable_metadata( | ||||
|                 # pyrefly: ignore  # bad-argument-type | ||||
|                 tma_descriptor_metadata.get(name, None) | ||||
|             ) | ||||
|         ) is not None: | ||||
| @ -425,6 +426,7 @@ def generate_ttir( | ||||
|                         specialize_value=not kp.do_not_specialize, | ||||
|                         align=not kp.do_not_specialize_on_alignment, | ||||
|                     ) | ||||
|                     # pyrefly: ignore  # unsupported-operation | ||||
|                     attrvals.append(spec[1]) | ||||
|  | ||||
|             attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str)) | ||||
| @ -443,6 +445,7 @@ def generate_ttir( | ||||
|         def get_signature_value(idx: int, arg: Any) -> str: | ||||
|             if kernel.params[idx].is_constexpr: | ||||
|                 return "constexpr" | ||||
|             # pyrefly: ignore  # not-callable | ||||
|             return mangle_type(arg) | ||||
|  | ||||
|     else: | ||||
| @ -815,6 +818,7 @@ def get_tma_stores( | ||||
|         for op in op_list: | ||||
|             if op.name == "tt.call": | ||||
|                 assert op.fn_call_name in functions | ||||
|                 # pyrefly: ignore  # bad-argument-type | ||||
|                 tma_stores = get_tma_stores(functions, op.fn_call_name) | ||||
|                 for i, inp in enumerate(op.args): | ||||
|                     if Param(idx=i) in tma_stores: | ||||
| @ -895,7 +899,11 @@ def analyze_kernel_mutations( | ||||
|             if op.name == "tt.call": | ||||
|                 assert op.fn_call_name in functions | ||||
|                 mutations = analyze_kernel_mutations( | ||||
|                     functions, op.fn_call_name, len(op.args) | ||||
|                     # pyrefly: ignore  # bad-argument-type | ||||
|                     functions, | ||||
|                     # pyrefly: ignore  # bad-argument-type | ||||
|                     op.fn_call_name, | ||||
|                     len(op.args), | ||||
|                 ) | ||||
|                 stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated) | ||||
|             else: | ||||
| @ -948,6 +956,7 @@ def identify_mutated_tensors( | ||||
|         assert functions is not None | ||||
|         kernel_name = next(iter(functions.keys())) | ||||
|         # Triton codegen modifies the name | ||||
|         # pyrefly: ignore  # missing-attribute | ||||
|         assert kernel.fn.__name__ in kernel_name | ||||
|         # Reset the cache between top level invocations | ||||
|         # The cache for analyze kernel mutations is mainly used for cycle | ||||
| @ -1051,7 +1060,11 @@ def triton_kernel_wrapper_mutation_dense( | ||||
|         grid_fn = grid[0] | ||||
|     else: | ||||
|         fn_name, code = user_defined_kernel_grid_fn_code( | ||||
|             kernel.fn.__name__, kernel.configs, grid | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             kernel.fn.__name__, | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             kernel.configs, | ||||
|             grid, | ||||
|         ) | ||||
|         namespace: dict[str, Any] = {} | ||||
|         exec(code, namespace) | ||||
| @ -1100,6 +1113,7 @@ def triton_kernel_wrapper_mutation_dense( | ||||
|     # avoid mutating the original inputs | ||||
|     kwargs = kwargs.copy() | ||||
|     constant_args = constant_args.copy() | ||||
|     # pyrefly: ignore  # missing-attribute | ||||
|     for name in kernel.arg_names: | ||||
|         if name in kwargs: | ||||
|             args.append(kwargs.pop(name)) | ||||
| @ -1108,6 +1122,7 @@ def triton_kernel_wrapper_mutation_dense( | ||||
|         else: | ||||
|             break | ||||
|  | ||||
|     # pyrefly: ignore  # index-error | ||||
|     kernel[grid_fn](*args, **kwargs, **constant_args) | ||||
|  | ||||
|  | ||||
| @ -1513,6 +1528,7 @@ class TritonHOPifier: | ||||
|  | ||||
|         assert kernel_idx is None or variable.kernel_idx == kernel_idx | ||||
|  | ||||
|         # pyrefly: ignore  # bad-assignment | ||||
|         variable.grid = grid | ||||
|  | ||||
|         if isinstance(kernel, Autotuner): | ||||
| @ -2057,6 +2073,7 @@ class TraceableTritonKernelWrapper: | ||||
|             return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None) | ||||
|         else: | ||||
|             assert self.kernel is not None | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             return self.kernel.run(*args, **kwargs) | ||||
|  | ||||
|     def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any: | ||||
| @ -2068,6 +2085,7 @@ class TraceableTritonKernelWrapper: | ||||
|             ) | ||||
|         else: | ||||
|             assert self.kernel is not None | ||||
|             # pyrefly: ignore  # index-error | ||||
|             return self.kernel[self.grid](*args, **kwargs) | ||||
|  | ||||
|     def specialize_symbolic(self, arg: Sequence[Any]) -> Any: | ||||
|  | ||||
| @ -270,6 +270,7 @@ def _set_compilation_env(): | ||||
|         # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo | ||||
|         # once we are confident fx tracing works with dynamo. | ||||
|         torch.fx._symbolic_trace._is_fx_tracing_flag = False | ||||
|         # pyrefly: ignore  # bad-assignment | ||||
|         torch._dynamo.config.allow_empty_graphs = True | ||||
|         torch._dynamo.config.capture_scalar_outputs = True | ||||
|         yield | ||||
| @ -440,6 +441,7 @@ def unique_graph_name_with_root( | ||||
| ) -> tuple[int, str]: | ||||
|     next_name = None | ||||
|     i = 0 | ||||
|     # pyrefly: ignore  # bad-assignment | ||||
|     while not next_name: | ||||
|         candidate = f"{prefix}_{i}" | ||||
|         if hasattr(root, candidate): | ||||
| @ -795,6 +797,8 @@ def create_bw_fn( | ||||
|     """ | ||||
|  | ||||
|     from torch._functorch.aot_autograd import AOTConfig, create_joint | ||||
|  | ||||
|     # pyrefly: ignore  # missing-module-attribute | ||||
|     from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad | ||||
|  | ||||
|     dummy_aot_config = AOTConfig( | ||||
| @ -939,6 +943,7 @@ def check_input_alias_and_mutation( | ||||
|         out_out_alias_map, | ||||
|         mutated_inputs, | ||||
|     ) = check_input_alias_and_mutation_return_outputs(gm)[:-1] | ||||
|     # pyrefly: ignore  # bad-return | ||||
|     return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -54,6 +54,7 @@ class WhileLoopOp(HigherOrderOperator): | ||||
|         validate_subgraph_args_types(additional_inputs) | ||||
|         return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def gen_schema(self, cond_fn, body_fn, carried_inputs, additional_inputs): | ||||
|         from torch._higher_order_ops.schema import HopSchemaGenerator | ||||
|         from torch._higher_order_ops.utils import materialize_as_graph | ||||
| @ -430,6 +431,7 @@ def while_loop_tracing( | ||||
|             elif isinstance(x, torch.Tensor): | ||||
|                 x = x.clone() | ||||
|                 if hasattr(x, "constant") and x.constant is not None: | ||||
|                     # pyrefly: ignore  # missing-attribute | ||||
|                     x.constant = None | ||||
|             return x | ||||
|  | ||||
| @ -452,6 +454,7 @@ def while_loop_tracing( | ||||
|  | ||||
|         next_name = None | ||||
|         i = 0 | ||||
|         # pyrefly: ignore  # bad-assignment | ||||
|         while not next_name: | ||||
|             candidate = f"while_loop_cond_graph_{i}" | ||||
|             if hasattr(proxy_mode.tracer.root, candidate): | ||||
| @ -696,6 +699,7 @@ class WhileLoopStackOutputOp(HigherOrderOperator): | ||||
|  | ||||
| class WhileLoopAutogradOp(torch.autograd.Function): | ||||
|     @staticmethod | ||||
|     # pyrefly: ignore  # bad-override | ||||
|     def forward( | ||||
|         ctx, | ||||
|         cond_fn, | ||||
| @ -725,6 +729,7 @@ class WhileLoopAutogradOp(torch.autograd.Function): | ||||
|         ctx.additional_inputs = additional_inputs | ||||
|         ctx.fw_outputs = fw_outputs | ||||
|         loop_count = None | ||||
|         # pyrefly: ignore  # bad-assignment | ||||
|         for out in fw_outputs: | ||||
|             if isinstance(out, torch.Tensor): | ||||
|                 if loop_count is not None: | ||||
| @ -878,6 +883,7 @@ class WhileLoopAutogradOp(torch.autograd.Function): | ||||
|             while_loop_op( | ||||
|                 cond_gm, | ||||
|                 body_gm, | ||||
|                 # pyrefly: ignore  # bad-argument-type | ||||
|                 ( | ||||
|                     init_idx, | ||||
|                     *init_grad_carries, | ||||
|  | ||||
| @ -880,10 +880,14 @@ def logsumexp( | ||||
|     if not isinstance(dim, Iterable): | ||||
|         dim = (dim,) | ||||
|     if self.numel() == 0: | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         return torch.sum(torch.exp(self), dim, keepdim).log() | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     maxes = torch.amax(torch.real(self), dim, keepdim=True) | ||||
|     maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0) | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim) | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     result = torch.sum(torch.exp(self - maxes), dim, keepdim) | ||||
|     return result.log().add(maxes_squeezed) | ||||
|  | ||||
| @ -1241,10 +1245,12 @@ def copysign( | ||||
|     a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] | ||||
| ): | ||||
|     if isinstance(b, Number) and isinstance(a, Tensor): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         b = scalar_tensor(b, dtype=a.dtype, device=a.device) | ||||
|     elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: | ||||
|         msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!" | ||||
|         raise RuntimeError(msg) | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     return where(signbit(b), neg(abs(a)), abs(a)) | ||||
|  | ||||
|  | ||||
| @ -1330,10 +1336,13 @@ def float_power( | ||||
|  | ||||
|     # Float power has the following contiguous cast behavior to be | ||||
|     # consistent with its C++ impl | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     a = _maybe_convert_to_dtype(a, dtype) | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     b = _maybe_convert_to_dtype(b, dtype) | ||||
|  | ||||
|     a, b = _maybe_broadcast(a, b) | ||||
|     # pyrefly: ignore  # bad-return | ||||
|     return pow(a, b) | ||||
|  | ||||
|  | ||||
| @ -1375,11 +1384,15 @@ def floor_divide( | ||||
| ): | ||||
|     # Wrap scalars because some references only accept tensor arguments. | ||||
|     if isinstance(a, Number) and isinstance(b, Number): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         a = scalar_tensor(a) | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         b = scalar_tensor(b) | ||||
|     elif isinstance(b, Number) and isinstance(a, Tensor): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         b = scalar_tensor(b, dtype=a.dtype, device=a.device) | ||||
|     elif isinstance(a, Number) and isinstance(b, Tensor): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         a = scalar_tensor(a, dtype=b.dtype, device=b.device) | ||||
|     elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: | ||||
|         if a.device == torch.device("cpu"): | ||||
| @ -1856,8 +1869,10 @@ def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberT | ||||
|  | ||||
|     # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. | ||||
|     if isinstance(b, TensorLike) and isinstance(a, Number): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         a = scalar_tensor(a, dtype=b.dtype, device=b.device) | ||||
|     elif isinstance(a, TensorLike) and isinstance(b, Number): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         b = scalar_tensor(b, dtype=a.dtype, device=a.device) | ||||
|  | ||||
|     # mypy: expected "Tensor" | ||||
| @ -2333,6 +2348,7 @@ def all( | ||||
|     dim: Optional[DimsType] = None, | ||||
|     keepdim: bool = False, | ||||
| ) -> TensorLikeType: | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim)) | ||||
|  | ||||
|     if a.dtype == torch.uint8: | ||||
| @ -2850,6 +2866,7 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: | ||||
|     # SymInts | ||||
|  | ||||
|     example = None | ||||
|     # pyrefly: ignore  # bad-assignment | ||||
|     for i, t in enumerate(tensors): | ||||
|         if example is None: | ||||
|             if t.ndim != 1: | ||||
| @ -3228,6 +3245,7 @@ def _normalize( | ||||
|         mean (Tensor): mean of the tensor along norm_dims. | ||||
|         rstd (Tensor): 1/std of the tensor along norm_dims. | ||||
|     """ | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) | ||||
|     computation_dtype = utils.get_computation_dtype(a.dtype) | ||||
|     a_acc = _maybe_convert_to_dtype(a, computation_dtype) | ||||
| @ -3341,6 +3359,7 @@ def native_layer_norm( | ||||
|     # while torch.Size([1, 2, 3]) == (1, 2, 3) is True | ||||
|     # therefore we use tuple(normalized_shape) | ||||
|     torch._check( | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         weight is None or sym_eq(weight.shape, tuple(normalized_shape)), | ||||
|         lambda: "Expected weight to be of same shape as normalized_shape, but got " | ||||
|         + "weight of shape " | ||||
| @ -3349,6 +3368,7 @@ def native_layer_norm( | ||||
|         + str(normalized_shape), | ||||
|     ) | ||||
|     torch._check( | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         bias is None or sym_eq(bias.shape, tuple(normalized_shape)), | ||||
|         lambda: "Expected bias to be of same shape as normalized_shape, but got " | ||||
|         + "bias of shape " | ||||
| @ -3359,7 +3379,10 @@ def native_layer_norm( | ||||
|     torch._check( | ||||
|         input.ndim >= normalized_ndim | ||||
|         and sym_eq( | ||||
|             input.shape[(input.ndim - normalized_ndim) :], tuple(normalized_shape) | ||||
|             # pyrefly: ignore  # bad-argument-type | ||||
|             input.shape[(input.ndim - normalized_ndim) :], | ||||
|             # pyrefly: ignore  # bad-argument-type | ||||
|             tuple(normalized_shape), | ||||
|         ), | ||||
|         lambda: "Given normalized_shape=" | ||||
|         + str(normalized_shape) | ||||
| @ -3953,6 +3976,7 @@ def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: | ||||
| @out_wrapper() | ||||
| def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType: | ||||
|     """Reference implementation of :func:`torch.roll`.""" | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     dims = utils.canonicalize_dims(a.ndim, dims) | ||||
|     # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1 | ||||
|     if not isinstance(shifts, Iterable): | ||||
| @ -3965,12 +3989,16 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike | ||||
|         # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors | ||||
|         return a.clone() | ||||
|  | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     if a.dim() == 0 and len(dims) > 0: | ||||
|         raise IndexError( | ||||
|             # pyrefly: ignore  # index-error | ||||
|             f"Dimension specified as {dims[0]} but tensor has no dimensions" | ||||
|         ) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     len_shifts = len(shifts) | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     len_dims = len(dims) | ||||
|     if len_shifts != 1 or len_dims != 1: | ||||
|         if len_shifts == 0: | ||||
| @ -3978,21 +4006,27 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike | ||||
|         # Takes care of the case when dims is not specified (default) | ||||
|         # By default, the tensor is flattened before shifting, after which the original shape is restored | ||||
|         if len_dims == 0 and len_shifts == 1: | ||||
|             # pyrefly: ignore  # bad-argument-type | ||||
|             return torch.roll(torch.flatten(a), shifts, 0).view(a.shape) | ||||
|         if len_shifts != len_dims: | ||||
|             raise RuntimeError( | ||||
|                 f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}" | ||||
|             ) | ||||
|         assert len_dims > 1 | ||||
|         # pyrefly: ignore  # index-error | ||||
|         tail_shifts = shifts[1:] | ||||
|         # pyrefly: ignore  # index-error | ||||
|         tail_dims = dims[1:] | ||||
|         # pyrefly: ignore  # index-error | ||||
|         first_dim_rolled = torch.roll(a, (shifts[0],), dims[0]) | ||||
|         return torch.roll(first_dim_rolled, tail_shifts, tail_dims) | ||||
|  | ||||
|     # This path is taken when only one dimension is rolled | ||||
|     # For example to get `first_dim_rolled` above | ||||
|     # pyrefly: ignore  # index-error | ||||
|     dim = dims[0] | ||||
|     size = a.shape[dim] | ||||
|     # pyrefly: ignore  # index-error | ||||
|     start = (size - shifts[0]) % size | ||||
|     idx = torch.arange(size, device=a.device) | ||||
|     return a.index_select(dim, torch.fmod(start + idx, size)) | ||||
| @ -4074,7 +4108,9 @@ def softmax( | ||||
|         a_max = amax(a_, dim, keepdim=True) | ||||
|         a_exp = exp(a_ - a_max) | ||||
|     return _maybe_convert_to_dtype( | ||||
|         true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         true_divide(a_exp, sum(a_exp, dim, keepdim=True)), | ||||
|         result_dtype, | ||||
|     )  # type: ignore[return-value] | ||||
|  | ||||
|  | ||||
| @ -4251,6 +4287,7 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType | ||||
|         return prims.squeeze(a, dims) if dims else prims.view_of(a) | ||||
|  | ||||
|     ndim = a.ndim | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     dim = utils.canonicalize_dims(ndim, dim) | ||||
|     dims = (dim,) if isinstance(dim, Dim) else dim | ||||
|     # Short-circuits if the tensor has no dimensions | ||||
| @ -4391,6 +4428,7 @@ def hsplit( | ||||
|     if isinstance(indices_or_sections, IntLike): | ||||
|         split_size = indices_or_sections | ||||
|         torch._check( | ||||
|             # pyrefly: ignore  # unsupported-operation | ||||
|             (split_size != 0 and a.shape[dim] % split_size == 0), | ||||
|             lambda: ( | ||||
|                 "torch.hsplit attempted to split along dimension " | ||||
| @ -4402,6 +4440,7 @@ def hsplit( | ||||
|                 + "!" | ||||
|             ), | ||||
|         ) | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         return tensor_split(a, split_size, dim) | ||||
|  | ||||
|     torch._check_type( | ||||
| @ -4432,6 +4471,7 @@ def vsplit( | ||||
|     if isinstance(indices_or_sections, IntLike): | ||||
|         split_size = indices_or_sections | ||||
|         torch._check( | ||||
|             # pyrefly: ignore  # unsupported-operation | ||||
|             (split_size != 0 and a.shape[0] % split_size == 0), | ||||
|             lambda: ( | ||||
|                 f"torch.vsplit attempted to split along dimension 0" | ||||
| @ -4442,6 +4482,7 @@ def vsplit( | ||||
|                 f"!" | ||||
|             ), | ||||
|         ) | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         return tensor_split(a, split_size, 0) | ||||
|  | ||||
|     torch._check_type( | ||||
| @ -4646,6 +4687,7 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType: | ||||
|         raise RuntimeError( | ||||
|             f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!" | ||||
|         ) | ||||
|     # pyrefly: ignore  # unsupported-operation | ||||
|     if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0): | ||||
|         raise RuntimeError( | ||||
|             "torch.dsplit attempted to split along dimension 2, " | ||||
| @ -5419,6 +5461,7 @@ def logspace( | ||||
|  | ||||
|  | ||||
| @overload | ||||
| # pyrefly: ignore  # inconsistent-overload | ||||
| def meshgrid(tensors: Sequence[TensorLikeType], indexing: str): | ||||
|     pass | ||||
|  | ||||
| @ -5845,6 +5888,7 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi | ||||
|  | ||||
|     # Since `where` allows type-promotion, | ||||
|     # cast value to correct type before passing to `where` | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     value = _maybe_convert_to_dtype(value, a.dtype) | ||||
|     r = torch.where(mask, value, a)  # type: ignore[arg-type] | ||||
|  | ||||
| @ -6639,6 +6683,7 @@ def _infer_scalar_type(obj): | ||||
|         # double. | ||||
|         if length == 0: | ||||
|             return torch.get_default_dtype() | ||||
|         # pyrefly: ignore  # bad-assignment | ||||
|         for i in range(length): | ||||
|             cur_item = obj[i] | ||||
|             # TODO: test this | ||||
| @ -6676,6 +6721,7 @@ def _recursive_build( | ||||
|         # torch.Size([1, 2]) | ||||
|         return obj.detach().to(dtype=scalarType, device="cpu", copy=True) | ||||
|     elif isinstance(obj, Number): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         return torch.scalar_tensor(obj, dtype=scalarType) | ||||
|  | ||||
|     # seq can be a list of tensors | ||||
|  | ||||
| @ -106,11 +106,13 @@ def _resize_fft_input( | ||||
|         if x_sizes[dims[i]] < sizes[i]: | ||||
|             must_copy = True | ||||
|             pad_idx = len(pad_amount) - 2 * dims[i] - 1 | ||||
|             # pyrefly: ignore  # unsupported-operation | ||||
|             pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]] | ||||
|  | ||||
|         if x_sizes[dims[i]] > sizes[i]: | ||||
|             x = x.narrow(dims[i], 0, sizes[i]) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-argument-type | ||||
|     return torch.constant_pad_nd(x, pad_amount) if must_copy else x | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -216,6 +216,7 @@ def matrix_norm( | ||||
|     # shape | ||||
|     check_is_matrix(A, "linalg.matrix_norm") | ||||
|     # dim | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     dim = utils.canonicalize_dims(A.ndim, dim) | ||||
|     if isinstance(dim, Dim): | ||||
|         dim = (dim,)  # type: ignore[assignment] | ||||
| @ -223,7 +224,9 @@ def matrix_norm( | ||||
|         len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}" | ||||
|     ) | ||||
|     torch._check( | ||||
|         # pyrefly: ignore  # index-error | ||||
|         dim[0] != dim[1], | ||||
|         # pyrefly: ignore  # index-error | ||||
|         lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})", | ||||
|     ) | ||||
|     # dtype arg | ||||
| @ -245,6 +248,7 @@ def matrix_norm( | ||||
|         else:  # ord == "nuc" | ||||
|             if dtype is not None: | ||||
|                 A = _maybe_convert_to_dtype(A, dtype)  # type: ignore[assignment] | ||||
|             # pyrefly: ignore  # index-error | ||||
|             perm = _backshift_permutation(dim[0], dim[1], A.ndim) | ||||
|             result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim) | ||||
|             if keepdim: | ||||
| @ -268,6 +272,7 @@ def matrix_norm( | ||||
|         if abs_ord == 2.0: | ||||
|             if dtype is not None: | ||||
|                 A = _maybe_convert_to_dtype(A, dtype)  # type: ignore[assignment] | ||||
|             # pyrefly: ignore  # index-error | ||||
|             perm = _backshift_permutation(dim[0], dim[1], A.ndim) | ||||
|             result = max_min(svdvals(prims.transpose(A, perm)), dim=-1) | ||||
|             if keepdim: | ||||
| @ -275,6 +280,7 @@ def matrix_norm( | ||||
|                 result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) | ||||
|             return result | ||||
|         else:  # 1, -1, inf, -inf | ||||
|             # pyrefly: ignore  # bad-unpacking | ||||
|             dim0, dim1 = dim | ||||
|             if abs_ord == float("inf"): | ||||
|                 dim0, dim1 = dim1, dim0 | ||||
|  | ||||
| @ -142,9 +142,11 @@ def _inplace_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]: | ||||
|     # nb. We use the name of the first argument used in the unary references | ||||
|     @wraps(fn) | ||||
|     def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: | ||||
|         # pyrefly: ignore  # unsupported-operation | ||||
|         a = args[0] | ||||
|         if "inplace" not in kwargs: | ||||
|             kwargs["inplace"] = False | ||||
|         # pyrefly: ignore  # unsupported-operation | ||||
|         if kwargs["inplace"]: | ||||
|             torch._check( | ||||
|                 "out" not in kwargs, | ||||
| @ -625,6 +627,7 @@ def smooth_l1_loss( | ||||
|         ) | ||||
|     else: | ||||
|         loss = torch.abs(input - target) | ||||
|         # pyrefly: ignore  # unsupported-operation | ||||
|         loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) | ||||
|         return _apply_loss_reduction(loss, reduction) | ||||
|  | ||||
|  | ||||
| @ -155,8 +155,10 @@ def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, Numbe | ||||
|  | ||||
|     # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. | ||||
|     if isinstance(a, TensorLike) and isinstance(b, Number): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device) | ||||
|     elif isinstance(b, TensorLike) and isinstance(a, Number): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device) | ||||
|  | ||||
|     # mypy: expected "Tensor" | ||||
|  | ||||
| @ -130,6 +130,7 @@ def var_decomposition( | ||||
|         else: | ||||
|             raise RuntimeError("correction must be int or float") | ||||
|  | ||||
|     # pyrefly: ignore  # no-matching-overload | ||||
|     return sum / max(0, denom) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -14,6 +14,8 @@ import torch | ||||
| _IS_MONKEYTYPE_INSTALLED = True | ||||
| try: | ||||
|     import monkeytype  # type: ignore[import] | ||||
|  | ||||
|     # pyrefly: ignore  # import-error | ||||
|     from monkeytype import trace as monkeytype_trace | ||||
|     from monkeytype.config import _startswith, LIB_PATHS  # type: ignore[import] | ||||
|     from monkeytype.db.base import (  # type: ignore[import] | ||||
| @ -87,6 +89,7 @@ if _IS_MONKEYTYPE_INSTALLED: | ||||
|             super().__init__(store) | ||||
|  | ||||
|         def log(self, trace: CallTrace) -> None: | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             self.traces.append(trace) | ||||
|  | ||||
|     class JitTypeTraceStore(CallTraceStore): | ||||
| @ -148,6 +151,7 @@ if _IS_MONKEYTYPE_INSTALLED: | ||||
|  | ||||
|         def trace_logger(self) -> JitTypeTraceStoreLogger: | ||||
|             """Return a JitCallTraceStoreLogger that logs to the configured trace store.""" | ||||
|             # pyrefly: ignore  # bad-argument-count | ||||
|             return JitTypeTraceStoreLogger(self.trace_store()) | ||||
|  | ||||
|         def trace_store(self) -> CallTraceStore: | ||||
|  | ||||
| @ -747,6 +747,7 @@ def get_overload_annotations(mod, jit_ignored_properties): | ||||
|             if method_overloads is None: | ||||
|                 continue | ||||
|  | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             if item.__func__ in method_overloads: | ||||
|                 raise RuntimeError( | ||||
|                     _jit_internal.get_overload_no_implementation_error_message( | ||||
|  | ||||
| @ -545,6 +545,7 @@ if _enabled: | ||||
|                 # | ||||
|                 # This ensures that if we use the attr again in `__init__`, it | ||||
|                 # will look like the actual value, not an instance of Attribute. | ||||
|                 # pyrefly: ignore  # invalid-argument | ||||
|                 if isinstance(value, Attribute): | ||||
|                     # NB: Ensure that we set __annotations__ on the specific | ||||
|                     # class in question, and not on a superclass (which would | ||||
| @ -656,6 +657,7 @@ if _enabled: | ||||
|  | ||||
|             # Finalize the ScriptModule: replace the nn.Module state with our | ||||
|             # custom implementations and flip the _initializing bit. | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             RecursiveScriptModule._finalize_scriptmodule(script_module) | ||||
|             return script_module | ||||
|  | ||||
| @ -929,6 +931,7 @@ if _enabled: | ||||
|                 # Don't do anything here, we'll initialize the ScriptModule below | ||||
|                 return | ||||
|  | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             return RecursiveScriptModule._construct( | ||||
|                 self._c._replicate_for_data_parallel(), init_fn | ||||
|             ) | ||||
| @ -938,6 +941,7 @@ if _enabled: | ||||
|     # This is because `super().foo()` does not use | ||||
|     # `__getattr__` to look up `foo`. So we need to make each method available on | ||||
|     # the ScriptModule manually. | ||||
|     # pyrefly: ignore  # missing-attribute | ||||
|     for name, item in RecursiveScriptModule.__dict__.items(): | ||||
|         if not callable(item) and not isinstance(item, property): | ||||
|             continue | ||||
| @ -1006,6 +1010,7 @@ if _enabled: | ||||
|         if name.startswith("__") or name.endswith("_call_impl"): | ||||
|             continue | ||||
|         if ( | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             name not in RecursiveScriptModule.__dict__ | ||||
|             and name not in _compiled_methods_allowlist | ||||
|         ): | ||||
| @ -1038,6 +1043,7 @@ def call_prepare_scriptable_func_impl(obj, memo): | ||||
|         return memo[id(obj)] | ||||
|  | ||||
|     obj = ( | ||||
|         # pyrefly: ignore  # not-callable | ||||
|         obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj | ||||
|     )  # type: ignore[operator] | ||||
|     # Record obj in memo to avoid infinite recursion in the case of cycles in the module | ||||
| @ -1135,6 +1141,7 @@ def _script_impl( | ||||
|         # the provide example inputs. This logs all the traces in type_trace_db | ||||
|         type_trace_db = JitTypeTraceStore() | ||||
|         if monkeytype_trace: | ||||
|             # pyrefly: ignore  # bad-argument-count | ||||
|             monkeytype_config = JitTypeTraceConfig(type_trace_db) | ||||
|             with monkeytype_trace(monkeytype_config): | ||||
|                 if isinstance(example_inputs, dict): | ||||
|  | ||||
| @ -166,11 +166,25 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False): | ||||
|     cu = torch._C.CompilationUnit() | ||||
|     if isinstance(f, (str, os.PathLike)): | ||||
|         cpp_module = torch._C.import_ir_module( | ||||
|             cu, os.fspath(f), map_location, _extra_files, _restore_shapes | ||||
|             # pyrefly: ignore  # no-matching-overload, bad-argument-count | ||||
|             cu, | ||||
|             # pyrefly: ignore  # no-matching-overload | ||||
|             os.fspath(f), | ||||
|             map_location, | ||||
|             _extra_files, | ||||
|             # pyrefly: ignore  # bad-argument-count | ||||
|             _restore_shapes, | ||||
|         )  # type: ignore[call-arg] | ||||
|     else: | ||||
|         cpp_module = torch._C.import_ir_module_from_buffer( | ||||
|             cu, f.read(), map_location, _extra_files, _restore_shapes | ||||
|             # pyrefly: ignore  # missing-attribute, bad-argument-count | ||||
|             cu, | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             f.read(), | ||||
|             map_location, | ||||
|             _extra_files, | ||||
|             # pyrefly: ignore  # bad-argument-count | ||||
|             _restore_shapes, | ||||
|         )  # type: ignore[call-arg] | ||||
|  | ||||
|     # TODO: Pretty sure this approach loses ConstSequential status and such | ||||
| @ -196,6 +210,7 @@ def validate_map_location(map_location=None): | ||||
|  | ||||
| def jit_module_from_flatbuffer(f): | ||||
|     if isinstance(f, (str, os.PathLike)): | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         f = os.fspath(f) | ||||
|         return wrap_cpp_module(torch._C._load_jit_module_from_file(f)) | ||||
|     else: | ||||
| @ -245,6 +260,7 @@ def save_jit_module_to_flatbuffer(m, f, _extra_files=None): | ||||
|         extra_files = {} | ||||
|  | ||||
|     if isinstance(f, (str, os.PathLike)): | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         f = os.fspath(f) | ||||
|         torch._C._save_jit_module(m._c, f, extra_files) | ||||
|     else: | ||||
|  | ||||
| @ -561,6 +561,7 @@ def cat(tensors: list[list[int]], dim: int): | ||||
|     for i in range(len(tensors)): | ||||
|         tensor = tensors[i] | ||||
|         if not should_skip(tensor): | ||||
|             # pyrefly: ignore  # bad-argument-type | ||||
|             check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i) | ||||
|             cat_dim_size = cat_dim_size + tensor[dim] | ||||
|  | ||||
|  | ||||
| @ -169,6 +169,7 @@ def _clone_inputs(args): | ||||
|         else: | ||||
|             return a.clone(memory_format=torch.preserve_format) | ||||
|  | ||||
|     # pyrefly: ignore  # missing-attribute | ||||
|     return function._nested_map( | ||||
|         lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors" | ||||
|     )(args) | ||||
| @ -335,6 +336,7 @@ def _check_trace( | ||||
|  | ||||
|         if is_trace_module: | ||||
|             copied_dict = {} | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             for name, data in inputs.items(): | ||||
|                 copied_dict[name] = _clone_inputs(data) | ||||
|             check_mod = torch.jit.trace_module( | ||||
| @ -739,6 +741,7 @@ def _trace_impl( | ||||
|         example_inputs = (example_inputs,) | ||||
|     # done primarily so that weird iterables fail here and not pybind11 code | ||||
|     elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple): | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         example_inputs = tuple(example_inputs) | ||||
|  | ||||
|     var_lookup_fn = _create_interpreter_name_lookup_fn(0) | ||||
| @ -765,6 +768,7 @@ def _trace_impl( | ||||
|         traced = torch._C._create_function_from_trace( | ||||
|             name, | ||||
|             func, | ||||
|             # pyrefly: ignore  # bad-argument-type | ||||
|             example_inputs, | ||||
|             var_lookup_fn, | ||||
|             strict, | ||||
|  | ||||
| @ -98,6 +98,7 @@ class EvalEnv: | ||||
|     def __init__(self, rcb): | ||||
|         self.rcb = rcb | ||||
|         if torch.distributed.rpc.is_available(): | ||||
|             # pyrefly: ignore  # unsupported-operation | ||||
|             self.env["RRef"] = RRef | ||||
|  | ||||
|     def __getitem__(self, name): | ||||
|  | ||||
| @ -115,6 +115,7 @@ node_start_tokens = { | ||||
|     ast.Continue: "continue", | ||||
| } | ||||
|  | ||||
| # pyrefly: ignore  # no-matching-overload | ||||
| pretty_node_names.update( | ||||
|     { | ||||
|         ast.AsyncFunctionDef: "async function definitions", | ||||
| @ -125,6 +126,7 @@ pretty_node_names.update( | ||||
|     } | ||||
| ) | ||||
|  | ||||
| # pyrefly: ignore  # no-matching-overload | ||||
| node_start_tokens.update( | ||||
|     { | ||||
|         ast.AsyncFunctionDef: "async def", | ||||
| @ -135,6 +137,7 @@ node_start_tokens.update( | ||||
|     } | ||||
| ) | ||||
|  | ||||
| # pyrefly: ignore  # no-matching-overload | ||||
| pretty_node_names.update( | ||||
|     { | ||||
|         ast.AnnAssign: "annotated assignments", | ||||
| @ -859,6 +862,7 @@ class ExprBuilder(Builder): | ||||
|         ast.RShift: ">>", | ||||
|     } | ||||
|  | ||||
|     # pyrefly: ignore  # unsupported-operation | ||||
|     binop_map[ast.MatMult] = "@" | ||||
|  | ||||
|     unop_map = { | ||||
| @ -1220,6 +1224,7 @@ class ExprBuilder(Builder): | ||||
|                 s += "{}" | ||||
|                 args.append(build_expr(ctx, value.value)) | ||||
|             elif isinstance(value, ast.Constant): | ||||
|                 # pyrefly: ignore  # unsupported-operation | ||||
|                 s += value.value | ||||
|             else: | ||||
|                 raise NotSupportedError(r, "Unsupported value in JoinedStr") | ||||
|  | ||||
| @ -44,10 +44,13 @@ def _load_for_lite_interpreter(f, map_location=None): | ||||
|     map_location = validate_map_location(map_location) | ||||
|  | ||||
|     if isinstance(f, (str, os.PathLike)): | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location) | ||||
|     else: | ||||
|         cpp_module = torch._C._load_for_lite_interpreter_from_buffer( | ||||
|             f.read(), map_location | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             f.read(), | ||||
|             map_location, | ||||
|         ) | ||||
|  | ||||
|     return LiteScriptModule(cpp_module) | ||||
| @ -103,8 +106,10 @@ def _get_model_bytecode_version(f_input) -> int: | ||||
|             raise ValueError(f"The provided filename {f_input} is a directory") | ||||
|  | ||||
|     if isinstance(f_input, (str, os.PathLike)): | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         return torch._C._get_model_bytecode_version(os.fspath(f_input)) | ||||
|     else: | ||||
|         # pyrefly: ignore  # missing-attribute | ||||
|         return torch._C._get_model_bytecode_version_from_buffer(f_input.read()) | ||||
|  | ||||
|  | ||||
| @ -135,8 +140,10 @@ def _get_mobile_model_contained_types(f_input) -> int: | ||||
|             raise ValueError(f"The provided filename {f_input} is a directory") | ||||
|  | ||||
|     if isinstance(f_input, (str, os.PathLike)): | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         return torch._C._get_mobile_model_contained_types(os.fspath(f_input)) | ||||
|     else: | ||||
|         # pyrefly: ignore  # missing-attribute | ||||
|         return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read()) | ||||
|  | ||||
|  | ||||
| @ -161,11 +168,18 @@ def _backport_for_mobile(f_input, f_output, to_version): | ||||
|         isinstance(f_output, (str, os.PathLike)) | ||||
|     ): | ||||
|         return torch._C._backport_for_mobile( | ||||
|             os.fspath(f_input), os.fspath(f_output), to_version | ||||
|             # pyrefly: ignore  # no-matching-overload | ||||
|             os.fspath(f_input), | ||||
|             # pyrefly: ignore  # no-matching-overload | ||||
|             os.fspath(f_output), | ||||
|             to_version, | ||||
|         ) | ||||
|     else: | ||||
|         return torch._C._backport_for_mobile_from_buffer( | ||||
|             f_input.read(), str(f_output), to_version | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             f_input.read(), | ||||
|             str(f_output), | ||||
|             to_version, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -184,10 +198,13 @@ def _backport_for_mobile_to_buffer(f_input, to_version): | ||||
|             raise ValueError(f"The provided filename {f_input} is a directory") | ||||
|  | ||||
|     if isinstance(f_input, (str, os.PathLike)): | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version) | ||||
|     else: | ||||
|         return torch._C._backport_for_mobile_from_buffer_to_buffer( | ||||
|             f_input.read(), to_version | ||||
|             # pyrefly: ignore  # missing-attribute | ||||
|             f_input.read(), | ||||
|             to_version, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -227,6 +244,8 @@ def _get_model_ops_and_info(f_input): | ||||
|             raise ValueError(f"The provided filename {f_input} is a directory") | ||||
|  | ||||
|     if isinstance(f_input, (str, os.PathLike)): | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         return torch._C._get_model_ops_and_info(os.fspath(f_input)) | ||||
|     else: | ||||
|         # pyrefly: ignore  # missing-attribute | ||||
|         return torch._C._get_model_ops_and_info(f_input.read()) | ||||
|  | ||||
| @ -261,6 +261,7 @@ def _get_global_builtins(): | ||||
|  | ||||
|     magic_methods_rows = [] | ||||
|     for fn, magic_method in magic_methods: | ||||
|         # pyrefly: ignore  # bad-argument-type | ||||
|         magic_methods_rows.append(f'"{fn}", "``{magic_method}``"') | ||||
|  | ||||
|     schematized_ops = [] | ||||
| @ -279,6 +280,7 @@ def _get_global_builtins(): | ||||
|             table_row = ( | ||||
|                 f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"' | ||||
|             ) | ||||
|             # pyrefly: ignore  # bad-argument-type | ||||
|             schemaless_ops.append(table_row) | ||||
|  | ||||
|     schematized_ops_str = "\n".join(schematized_ops) | ||||
|  | ||||
| @ -78,6 +78,7 @@ def _adjust_lr( | ||||
|     A, B = param_shape[:2] | ||||
|  | ||||
|     if adjust_lr_fn is None or adjust_lr_fn == "original": | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         adjusted_ratio = math.sqrt(max(1, A / B)) | ||||
|     elif adjust_lr_fn == "match_rms_adamw": | ||||
|         adjusted_ratio = 0.2 * math.sqrt(max(A, B)) | ||||
|  | ||||
| @ -415,6 +415,7 @@ def _single_tensor_adam( | ||||
|                     if weight_decay.requires_grad: | ||||
|                         grad = grad.addcmul_(param.clone(), weight_decay) | ||||
|                     else: | ||||
|                         # pyrefly: ignore  # bad-argument-type | ||||
|                         grad = grad.add(param, alpha=weight_decay) | ||||
|                 else: | ||||
|                     grad = grad.add(param, alpha=weight_decay) | ||||
| @ -444,6 +445,7 @@ def _single_tensor_adam( | ||||
|             device_beta1 = beta1 | ||||
|  | ||||
|         # Decay the first and second moment running average coefficient | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         exp_avg.lerp_(grad, 1 - device_beta1) | ||||
|  | ||||
|         # Nested if is necessary to bypass jitscript rules | ||||
| @ -692,6 +694,7 @@ def _multi_tensor_adam( | ||||
|             device_exp_avgs, device_grads, cast(float, 1 - device_beta1) | ||||
|         ) | ||||
|  | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         torch._foreach_mul_(device_exp_avg_sqs, beta2) | ||||
|  | ||||
|         # Due to the strictness of the _foreach_addcmul API, we can't have a single | ||||
|  | ||||
| @ -263,6 +263,7 @@ def _single_tensor_asgd( | ||||
|             ax.copy_(param) | ||||
|  | ||||
|         if capturable: | ||||
|             # pyrefly: ignore  # unsupported-operation | ||||
|             eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha)) | ||||
|             mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t))) | ||||
|         else: | ||||
|  | ||||
| @ -113,9 +113,11 @@ def _strong_wolfe( | ||||
|  | ||||
|         # compute new trial value | ||||
|         t = _cubic_interpolate( | ||||
|             # pyrefly: ignore  # index-error | ||||
|             bracket[0], | ||||
|             bracket_f[0], | ||||
|             bracket_gtd[0],  # type: ignore[possibly-undefined] | ||||
|             # pyrefly: ignore  # index-error | ||||
|             bracket[1], | ||||
|             bracket_f[1], | ||||
|             bracket_gtd[1], | ||||
| @ -151,6 +153,7 @@ def _strong_wolfe( | ||||
|  | ||||
|         if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: | ||||
|             # Armijo condition not satisfied or not lower than lowest point | ||||
|             # pyrefly: ignore  # unsupported-operation | ||||
|             bracket[high_pos] = t | ||||
|             bracket_f[high_pos] = f_new | ||||
|             bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined] | ||||
| @ -160,14 +163,17 @@ def _strong_wolfe( | ||||
|             if abs(gtd_new) <= -c2 * gtd: | ||||
|                 # Wolfe conditions satisfied | ||||
|                 done = True | ||||
|             # pyrefly: ignore  # index-error | ||||
|             elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: | ||||
|                 # old high becomes new low | ||||
|                 # pyrefly: ignore  # unsupported-operation | ||||
|                 bracket[high_pos] = bracket[low_pos] | ||||
|                 bracket_f[high_pos] = bracket_f[low_pos] | ||||
|                 bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined] | ||||
|                 bracket_gtd[high_pos] = bracket_gtd[low_pos] | ||||
|  | ||||
|             # new point becomes new low | ||||
|             # pyrefly: ignore  # unsupported-operation | ||||
|             bracket[low_pos] = t | ||||
|             bracket_f[low_pos] = f_new | ||||
|             bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined] | ||||
| @ -252,6 +258,7 @@ class LBFGS(Optimizer): | ||||
|  | ||||
|     def _numel(self): | ||||
|         if self._numel_cache is None: | ||||
|             # pyrefly: ignore  # bad-assignment | ||||
|             self._numel_cache = sum( | ||||
|                 2 * p.numel() if torch.is_complex(p) else p.numel() | ||||
|                 for p in self._params | ||||
|  | ||||
| @ -1665,6 +1665,7 @@ class ReduceLROnPlateau(LRScheduler): | ||||
|             self.default_min_lr = None | ||||
|             self.min_lrs = list(min_lr) | ||||
|         else: | ||||
|             # pyrefly: ignore  # bad-assignment | ||||
|             self.default_min_lr = min_lr | ||||
|             self.min_lrs = [min_lr] * len(optimizer.param_groups) | ||||
|  | ||||
| @ -1724,6 +1725,7 @@ class ReduceLROnPlateau(LRScheduler): | ||||
|                     "of the `optimizer` param groups." | ||||
|                 ) | ||||
|             else: | ||||
|                 # pyrefly: ignore  # bad-assignment | ||||
|                 self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups) | ||||
|  | ||||
|         for i, param_group in enumerate(self.optimizer.param_groups): | ||||
| @ -1903,10 +1905,13 @@ class CyclicLR(LRScheduler): | ||||
|  | ||||
|         self.max_lrs = _format_param("max_lr", optimizer, max_lr) | ||||
|  | ||||
|         # pyrefly: ignore  # bad-assignment | ||||
|         step_size_up = float(step_size_up) | ||||
|         step_size_down = ( | ||||
|             # pyrefly: ignore  # bad-assignment | ||||
|             float(step_size_down) if step_size_down is not None else step_size_up | ||||
|         ) | ||||
|         # pyrefly: ignore  # unsupported-operation | ||||
|         self.total_size = step_size_up + step_size_down | ||||
|         self.step_ratio = step_size_up / self.total_size | ||||
|  | ||||
|  | ||||
| @ -62,6 +62,7 @@ def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]: | ||||
|     def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T: | ||||
|         import torch._dynamo | ||||
|  | ||||
|         # pyrefly: ignore  # unsupported-operation | ||||
|         self = cast(Optimizer, args[0])  # assume first positional arg is `self` | ||||
|         prev_grad = torch.is_grad_enabled() | ||||
|         try: | ||||
| @ -135,11 +136,13 @@ def _disable_dynamo_if_unsupported( | ||||
|             if torch.compiler.is_compiling() and ( | ||||
|                 not kwargs.get("capturable", False) | ||||
|                 and has_state_steps | ||||
|                 # pyrefly: ignore  # unsupported-operation | ||||
|                 and (arg := args[state_steps_ind]) | ||||
|                 and isinstance(arg, Sequence) | ||||
|                 and arg[0].is_cuda | ||||
|                 or ( | ||||
|                     "state_steps" in kwargs | ||||
|                     # pyrefly: ignore  # unsupported-operation | ||||
|                     and (kwarg := kwargs["state_steps"]) | ||||
|                     and isinstance(kwarg, Sequence) | ||||
|                     and kwarg[0].is_cuda | ||||
| @ -359,14 +362,18 @@ class Optimizer: | ||||
|  | ||||
|     _optimizer_step_pre_hooks: dict[int, OptimizerPreHook] | ||||
|     _optimizer_step_post_hooks: dict[int, OptimizerPostHook] | ||||
|     # pyrefly: ignore  # not-a-type | ||||
|     _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' | ||||
|     _optimizer_state_dict_post_hooks: ( | ||||
|         # pyrefly: ignore  # not-a-type | ||||
|         'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' | ||||
|     ) | ||||
|     _optimizer_load_state_dict_pre_hooks: ( | ||||
|         # pyrefly: ignore  # not-a-type | ||||
|         'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' | ||||
|     ) | ||||
|     _optimizer_load_state_dict_post_hooks: ( | ||||
|         # pyrefly: ignore  # not-a-type | ||||
|         'OrderedDict[int, Callable[["Optimizer"], None]]' | ||||
|     ) | ||||
|  | ||||
| @ -391,6 +398,7 @@ class Optimizer: | ||||
|         self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict) | ||||
|         self.param_groups: list[dict[str, Any]] = [] | ||||
|  | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         param_groups = list(params) | ||||
|         if len(param_groups) == 0: | ||||
|             raise ValueError("optimizer got an empty parameter list") | ||||
| @ -514,6 +522,7 @@ class Optimizer: | ||||
|                                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}." | ||||
|                             ) | ||||
|  | ||||
|                 # pyrefly: ignore  # invalid-param-spec | ||||
|                 out = func(*args, **kwargs) | ||||
|                 self._optimizer_step_code() | ||||
|  | ||||
| @ -949,7 +958,14 @@ class Optimizer: | ||||
|             r"""Make a deep copy of value, casting all tensors to device of param.""" | ||||
|             if isinstance(value, torch.Tensor): | ||||
|                 return Optimizer._process_value_according_to_param_policy( | ||||
|                     param, value, param_id, param_groups, key | ||||
|                     # pyrefly: ignore  # bad-argument-type | ||||
|                     param, | ||||
|                     value, | ||||
|                     # pyrefly: ignore  # bad-argument-type | ||||
|                     param_id, | ||||
|                     # pyrefly: ignore  # bad-argument-type | ||||
|                     param_groups, | ||||
|                     key, | ||||
|                 ) | ||||
|             elif isinstance(value, dict): | ||||
|                 return { | ||||
| @ -960,6 +976,7 @@ class Optimizer: | ||||
|                 } | ||||
|             elif isinstance(value, Iterable): | ||||
|                 return type(value)( | ||||
|                     # pyrefly: ignore  # bad-argument-count | ||||
|                     _cast(param, v, param_id=param_id, param_groups=param_groups) | ||||
|                     for v in value | ||||
|                 )  # type: ignore[call-arg] | ||||
|  | ||||
| @ -322,6 +322,7 @@ def _single_tensor_radam( | ||||
|         rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2 | ||||
|  | ||||
|         def _compute_rect(): | ||||
|             # pyrefly: ignore  # unsupported-operation | ||||
|             return ( | ||||
|                 (rho_t - 4) | ||||
|                 * (rho_t - 2) | ||||
| @ -336,6 +337,7 @@ def _single_tensor_radam( | ||||
|             else: | ||||
|                 exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps) | ||||
|  | ||||
|             # pyrefly: ignore  # unsupported-operation | ||||
|             return (bias_correction2**0.5) / exp_avg_sq_sqrt | ||||
|  | ||||
|         # Compute the variance rectification term and update parameters accordingly | ||||
|  | ||||
| @ -337,6 +337,7 @@ def _single_tensor_sgd( | ||||
|     if not torch.jit.is_scripting(): | ||||
|         lr = _to_scalar(lr) | ||||
|  | ||||
|     # pyrefly: ignore  # bad-assignment | ||||
|     for i, param in enumerate(params): | ||||
|         grad = grads[i] if not maximize else -grads[i] | ||||
|  | ||||
| @ -347,6 +348,7 @@ def _single_tensor_sgd( | ||||
|                     # usually this is the differentiable path, which is why the param.clone() is needed | ||||
|                     grad = grad.addcmul_(param.clone(), weight_decay) | ||||
|                 else: | ||||
|                     # pyrefly: ignore  # bad-argument-type | ||||
|                     grad = grad.add(param, alpha=weight_decay) | ||||
|             else: | ||||
|                 grad = grad.add(param, alpha=weight_decay) | ||||
| @ -370,6 +372,7 @@ def _single_tensor_sgd( | ||||
|             if lr.requires_grad: | ||||
|                 param.addcmul_(grad, lr, value=-1) | ||||
|             else: | ||||
|                 # pyrefly: ignore  # bad-argument-type | ||||
|                 param.add_(grad, alpha=-lr) | ||||
|         else: | ||||
|             param.add_(grad, alpha=-lr) | ||||
| @ -430,10 +433,12 @@ def _multi_tensor_sgd( | ||||
|  | ||||
|             all_states_with_momentum_buffer = True | ||||
|             for i in range(len(device_momentum_buffer_list)): | ||||
|                 # pyrefly: ignore  # index-error | ||||
|                 if device_momentum_buffer_list[i] is None: | ||||
|                     all_states_with_momentum_buffer = False | ||||
|                     break | ||||
|                 else: | ||||
|                     # pyrefly: ignore  # index-error | ||||
|                     bufs.append(cast(Tensor, device_momentum_buffer_list[i])) | ||||
|  | ||||
|             if all_states_with_momentum_buffer: | ||||
| @ -441,12 +446,15 @@ def _multi_tensor_sgd( | ||||
|                 torch._foreach_add_(bufs, device_grads, alpha=1 - dampening) | ||||
|             else: | ||||
|                 bufs = [] | ||||
|                 # pyrefly: ignore  # bad-assignment | ||||
|                 for i in range(len(device_momentum_buffer_list)): | ||||
|                     # pyrefly: ignore  # index-error | ||||
|                     if device_momentum_buffer_list[i] is None: | ||||
|                         buf = device_momentum_buffer_list[i] = momentum_buffer_list[ | ||||
|                             indices[i] | ||||
|                         ] = device_grads[i].detach().clone() | ||||
|                     else: | ||||
|                         # pyrefly: ignore  # index-error | ||||
|                         buf = cast(Tensor, device_momentum_buffer_list[i]) | ||||
|                         buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening) | ||||
|  | ||||
|  | ||||
| @ -249,11 +249,13 @@ class AveragedModel(Module): | ||||
|     def update_parameters(self, model: Module): | ||||
|         """Update model parameters.""" | ||||
|         self_param = ( | ||||
|             # pyrefly: ignore  # bad-argument-type | ||||
|             itertools.chain(self.module.parameters(), self.module.buffers()) | ||||
|             if self.use_buffers | ||||
|             else self.parameters() | ||||
|         ) | ||||
|         model_param = ( | ||||
|             # pyrefly: ignore  # bad-argument-type | ||||
|             itertools.chain(model.parameters(), model.buffers()) | ||||
|             if self.use_buffers | ||||
|             else model.parameters() | ||||
| @ -300,8 +302,11 @@ class AveragedModel(Module): | ||||
|                 for p_averaged, p_model in zip(  # type: ignore[assignment] | ||||
|                     self_param_detached, model_param_detached | ||||
|                 ): | ||||
|                     # pyrefly: ignore  # missing-attribute | ||||
|                     n_averaged = self.n_averaged.to(p_averaged.device) | ||||
|                     # pyrefly: ignore  # missing-attribute | ||||
|                     p_averaged.detach().copy_( | ||||
|                         # pyrefly: ignore  # missing-attribute, bad-argument-type | ||||
|                         self.avg_fn(p_averaged.detach(), p_model, n_averaged) | ||||
|                     ) | ||||
|  | ||||
| @ -489,12 +494,14 @@ class SWALR(LRScheduler): | ||||
|         step = self._step_count - 1 | ||||
|         if self.anneal_epochs == 0: | ||||
|             step = max(1, step) | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs))) | ||||
|         prev_alpha = self.anneal_func(prev_t) | ||||
|         prev_lrs = [ | ||||
|             self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha) | ||||
|             for group in self.optimizer.param_groups | ||||
|         ] | ||||
|         # pyrefly: ignore  # no-matching-overload | ||||
|         t = max(0, min(1, step / max(1, self.anneal_epochs))) | ||||
|         alpha = self.anneal_func(t) | ||||
|         return [ | ||||
|  | ||||
		Reference in New Issue
	
	Block a user