diff --git a/pyrefly.toml b/pyrefly.toml
index 73b0e9d28122..0279803aa5a5 100644
--- a/pyrefly.toml
+++ b/pyrefly.toml
@@ -22,14 +22,16 @@ project-excludes = [
     # ==== to test Pyrefly on a specific directory, simply comment it out ====
     "torch/_inductor/**",
     "torch/distributed/**",
-    "torch/nn/**",
-    "torch/_dynamo/**",
     # formatting issues
     "torch/linalg/__init__.py",
     "torch/package/importer.py",
     "torch/package/_package_pickler.py",
     "torch/jit/annotations.py",
     "torch/utils/data/datapipes/_typing.py",
+    "torch/nn/functional.py",
+    "torch/_export/utils.py",
+    "torch/fx/experimental/unification/multipledispatch/__init__.py",
+    "torch/nn/modules/__init__.py",
     # ====
     "benchmarks/instruction_counts/main.py",
     "benchmarks/instruction_counts/definitions/setup.py",
diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py
index f1c2db025383..221502ae3190 100644
--- a/test/test_bundled_inputs.py
+++ b/test/test_bundled_inputs.py
@@ -59,7 +59,6 @@ class TestBundledInputs(TestCase):
             # despite having nominally large bundled inputs.
             augmented_size = model_size(sm)
-            # pyrefly: ignore # missing-attribute
             self.assertLess(augmented_size, original_size + (1 << 12))

             loaded = save_and_load(sm)
@@ -67,15 +66,12 @@ class TestBundledInputs(TestCase):
             self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
             self.assertEqual(len(inflated), len(samples))
-            # pyrefly: ignore # missing-attribute
             self.assertTrue(loaded(*inflated[0]) is inflated[0][0])

             for idx, inp in enumerate(inflated):
-                # pyrefly: ignore # missing-attribute
                 self.assertIsInstance(inp, tuple)
                 self.assertEqual(len(inp), 1)
-                # pyrefly: ignore # missing-attribute
                 self.assertIsInstance(inp[0], torch.Tensor)
                 if idx != 5:
                     # Strides might be important for benchmarking.
@@ -144,7 +140,6 @@ class TestBundledInputs(TestCase):
         inflated = loaded.get_all_bundled_inputs()
         self.assertEqual(inflated, samples)
-        # pyrefly: ignore # missing-attribute
         self.assertTrue(loaded(*inflated[0]) == "first 1")

     def test_multiple_methods_with_inputs(self):
@@ -192,7 +187,6 @@ class TestBundledInputs(TestCase):

         # Check running and size helpers
-        # pyrefly: ignore # missing-attribute
         self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
         self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
@@ -426,7 +420,6 @@ class TestBundledInputs(TestCase):
         augmented_size = model_size(sm)
         # assert the size has not increased more than 8KB
-        # pyrefly: ignore # missing-attribute
         self.assertLess(augmented_size, original_size + (1 << 13))

         loaded = save_and_load(sm)
diff --git a/test/test_complex.py b/test/test_complex.py
index 972fe3f0fd1c..9941b68c1757 100644
--- a/test/test_complex.py
+++ b/test/test_complex.py
@@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
     def test_all(self, device, dtype):
         # issue: https://github.com/pytorch/pytorch/issues/120875
         x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
-        # pyrefly: ignore # missing-attribute
+
         self.assertTrue(torch.all(x))

     @dtypes(*complex_types())
@@ -57,7 +57,7 @@ class TestComplexTensor(TestCase):
         x = torch.tensor(
             [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
         )
-        # pyrefly: ignore # missing-attribute
+
         self.assertFalse(torch.any(x))

     @onlyCPU
diff --git a/test/test_type_hints.py b/test/test_type_hints.py
index c982ae19b6df..0aae54be9b63 100644
--- a/test/test_type_hints.py
+++ b/test/test_type_hints.py
@@ -142,7 +142,6 @@ class TestTypeHints(TestCase):
             ]
         )
         if result != 0:
-            # pyrefly: ignore # missing-attribute
             self.fail(f"mypy failed:\n{stderr}\n{stdout}")
diff --git a/test/test_type_info.py b/test/test_type_info.py
index 48dc083fed1e..7f7ca388ebb9 100644
--- a/test/test_type_info.py
+++ b/test/test_type_info.py
@@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
         # Regression test for https://github.com/pytorch/pytorch/issues/124868
         # If reference count is leaked this would be a set of 10 elements
         ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
-        # pyrefly: ignore # missing-attribute
+
         self.assertLess(len(ref_cnt), 3)

         self.assertEqual(torch.float64.to_complex(), torch.complex128)
@@ -136,7 +136,7 @@ class TestDTypeInfo(TestCase):
         # Regression test for https://github.com/pytorch/pytorch/issues/124868
         # If reference count is leaked this would be a set of 10 elements
         ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
-        # pyrefly: ignore # missing-attribute
+
         self.assertLess(len(ref_cnt), 3)

         self.assertEqual(torch.complex128.to_real(), torch.double)
diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py
index 561acf62f785..44ce8bcf939d 100644
--- a/torch/_dynamo/__init__.py
+++ b/torch/_dynamo/__init__.py
@@ -53,6 +53,8 @@ from .eval_frame import (
     OptimizedModule,
     reset_code,
 )
+
+# pyrefly: ignore # deprecated
 from .external_utils import is_compiling
 from .mutation_guard import GenerationTracker
 from .pgo import reset_code_state
diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py
index 9b000ee926a1..69ffb830c945 100644
--- a/torch/_dynamo/_trace_wrapped_higher_order_op.py
+++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py
@@ -95,6 +95,7 @@ class ModIndex(torch.autograd.Function):
     generate_vmap_rule = True

     @staticmethod
+    # pyrefly: ignore # bad-override
     def forward(x: Tensor, indices: list[Tensor]) -> Tensor:
         return torch.ops.aten.index(x, indices)

@@ -242,6 +243,7 @@ def _trace_wrapped_functionalized(ctx: Any, *args: Any, **kwargs: Any) -> Any:

 def autograd_function_backward_rewritten(original_backward: Any) -> Any:
     def new_backward(ctx: Any, *grads: Any) -> Any:
+        # pyrefly: ignore # bad-assignment
         grads = [g.contiguous() for g in grads]
         return original_backward(ctx, *grads)
diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py
index 35a1fa69c954..142e244067ba 100644
--- a/torch/_dynamo/aot_compile.py
+++ b/torch/_dynamo/aot_compile.py
@@ -89,6 +89,7 @@ class AOTCompiledFunction:
             **import_sources,
             self._artifacts.backend_id: self._artifacts.compiled_fn,
         }
+        # pyrefly: ignore # read-only
         self.fn = types.FunctionType(
             self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
         )
diff --git a/torch/_dynamo/backends/cudagraphs.py b/torch/_dynamo/backends/cudagraphs.py
index f8599d393833..d6775d0841d8 100644
--- a/torch/_dynamo/backends/cudagraphs.py
+++ b/torch/_dynamo/backends/cudagraphs.py
@@ -206,6 +206,7 @@ def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any])
     assert manager is not None

     def fn(inputs: list[Any]) -> Any:
+        # pyrefly: ignore # missing-attribute
         manager.set_to_running_backward()
         return aot_model(inputs)
diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py
index 7e2ab19bb9c0..4820916c1212 100644
--- a/torch/_dynamo/backends/tvm.py
+++ b/torch/_dynamo/backends/tvm.py
@@ -77,16 +77,19 @@ def tvm(
     opt_level = options.get("opt_level", 3)

     if scheduler == "auto_scheduler":
+        # pyrefly: ignore # import-error
         from tvm import auto_scheduler

         log_file = tempfile.NamedTemporaryFile()

+        # pyrefly: ignore # bad-argument-type
         if not os.path.exists(log_file):
            tasks, task_weights = auto_scheduler.extract_tasks(
                 mod["main"], params, target
             )
             if len(tasks) != 0:
                 tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+                # pyrefly: ignore # bad-argument-type
                 if not os.path.exists(log_file):
                     assert trials > 0
                     tune_option = auto_scheduler.TuningOptions(
@@ -97,7 +100,9 @@ def tvm(
                     try:
                         tuner.tune(tune_option)
                     except Exception:
+                        # pyrefly: ignore # bad-argument-type
                         if os.path.exists(log_file):
+                            # pyrefly: ignore # bad-argument-type
                             os.unlink(log_file)
                         raise

@@ -107,6 +112,7 @@ def tvm(
         ):
             lib = relay.build(mod, target=target, params=params)
     elif scheduler == "meta_schedule":
+        # pyrefly: ignore # import-error
         from tvm import meta_schedule as ms

         with tempfile.TemporaryDirectory() as work_dir:
diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py
index 3ccbd56bfada..0e98f873e479 100644
--- a/torch/_dynamo/bytecode_analysis.py
+++ b/torch/_dynamo/bytecode_analysis.py
@@ -37,6 +37,7 @@ if sys.version_info >= (3, 11):
     TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
 else:
     TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
+# pyrefly: ignore # unsupported-operation
 if (3, 12) <= sys.version_info < (3, 14):
     TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
 if sys.version_info >= (3, 13):
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
index 48d667319a11..230de4964154 100644
--- a/torch/_dynamo/bytecode_transformation.py
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -903,6 +903,7 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None:
                 inst.arg = abs(
                     int(target.offset - inst.offset - instruction_size(inst))
                 )
+                # pyrefly: ignore # unsupported-operation
                 inst.arg //= 2
             inst.argval = target.offset
             inst.argrepr = f"to {target.offset}"
@@ -1354,6 +1355,7 @@ def update_offsets(instructions: Sequence[Instruction]) -> None:
     offset = 0
     for inst in instructions:
         inst.offset = offset
+        # pyrefly: ignore # unsupported-operation
         offset += instruction_size(inst)
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 438af14886bb..f4850f7e5e9b 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -464,6 +464,7 @@ def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
         try:
             prof.enable()
             start_ts = time.time()
+            # pyrefly: ignore # bad-argument-type
             retval = prof.runcall(func, *args, **kwargs)
             profile_latency = time.time() - start_ts
             prof.disable()
@@ -957,6 +958,7 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
     if isinstance(mod, torch.nn.Module):
         mod = mod.forward
     if hasattr(mod, "__self__"):
+        # pyrefly: ignore # missing-attribute
         return mod.__func__, mod.__self__
     elif inspect.isfunction(mod):
         return mod, None
@@ -1096,6 +1098,7 @@ def _fullgraph_capture_frame(
             while cur_exn.__cause__ is not None:
                 cur_exn.__cause__.with_traceback(None)
                 cur_exn = cur_exn.__cause__
+            # pyrefly: ignore # invalid-inheritance
             raise e.with_traceback(None) from e.__cause__  # User compiler error

     return CaptureOutput(
@@ -1119,6 +1122,7 @@ def compile_frame(  # type: ignore[return]
     frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
     distributed_state: Optional[DistributedState] = None,
     package: Optional[CompilePackage] = None,
    # pyrefly: ignore # bad-return
 ) -> DynamoOutput:
     """
     A helper function taking a frame and backend, then return the generated bytecode
diff --git a/torch/_dynamo/create_parameter_op.py b/torch/_dynamo/create_parameter_op.py
index ded3ef75ed1d..63f6704370b8 100644
--- a/torch/_dynamo/create_parameter_op.py
+++ b/torch/_dynamo/create_parameter_op.py
@@ -20,6 +20,7 @@ allowed to compute gradients on).
 class TracableCreateParameter(torch.autograd.Function):
     @staticmethod
+    # pyrefly: ignore # bad-override
     def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter:
         assert not tensor.requires_grad
         return placeholder.set_(tensor)
diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py
index d7905cd45e0a..493533f9ae8b 100644
--- a/torch/_dynamo/debug_utils.py
+++ b/torch/_dynamo/debug_utils.py
@@ -879,6 +879,7 @@ def aot_graph_input_parser(
             data_type, shape_str = match.groups()
             shape = tuple(shape_str.split(","))
             dtype = dtype_map[data_type]
+            # pyrefly: ignore # bad-argument-type
             kwargs[param] = gen_tensor(shape, dtype)

         match = re.search(sym_shape_regex, annotation)
@@ -892,6 +893,7 @@ def aot_graph_input_parser(
             attr_name, data_type, shape_str, _ = match.groups()
             shape = tuple(shape_str.split(","))
             dtype = dtype_map[data_type]
+            # pyrefly: ignore # bad-argument-type
             setattr(container, attr_name, gen_tensor(shape, dtype))

     return kwargs
diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py
index bb66e79b6557..0a1066cf8dc4 100644
--- a/torch/_dynamo/decorators.py
+++ b/torch/_dynamo/decorators.py
@@ -95,6 +95,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True):  # type: ig
         nonrecursive_disable_wrapper._torchdynamo_disable = True  # type: ignore[attr-defined]
         nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason  # type: ignore[attr-defined]
         nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
+        # pyrefly: ignore # bad-return
         return nonrecursive_disable_wrapper

     if fn is None:
@@ -306,6 +307,7 @@ def forbid_in_graph(fn: Any) -> Any:
     if isinstance(fn, (list, tuple)):
         return [forbid_in_graph(x) for x in fn]
     assert callable(fn), "forbid_in_graph applies only to callables"
+    # pyrefly: ignore # missing-attribute
     fn._dynamo_forbidden = True
     return fn

@@ -653,21 +655,28 @@ def mark_dynamic(
     if isinstance(index, int):
         if not hasattr(t, "_dynamo_dynamic_indices"):
+            # pyrefly: ignore # missing-attribute
             t._dynamo_dynamic_indices = set()
+            # pyrefly: ignore # missing-attribute
             t._dynamo_dynamic_range = set()
+            # pyrefly: ignore # missing-attribute
             t._dynamo_hint_overrides = {}

         if not hasattr(t, "_specialize_on"):
+            # pyrefly: ignore # missing-attribute
             t._specialize_on = {}

         if hint_override:
+            # pyrefly: ignore # missing-attribute
             t._dynamo_hint_overrides[index] = hint_override
         # TODO(voz): Should we bounds check?
+        # pyrefly: ignore # missing-attribute
         t._dynamo_dynamic_indices.add(index)
         t._dynamo_dynamic_range.add(_DimRange(index, min, max))  # type: ignore[arg-type]

         # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
         # TypeError: 'Attribute' object does not support item assignment
+        # pyrefly: ignore # missing-attribute
         if isinstance(t._specialize_on, dict):
             t._specialize_on[index] = specialize_on if specialize_on is not None else []

@@ -692,8 +701,10 @@ def maybe_mark_dynamic(t: Any, index: Union[int, list[Any], tuple[Any]]) -> None
     if isinstance(index, int):
         if not hasattr(t, "_dynamo_weak_dynamic_indices"):
+            # pyrefly: ignore # missing-attribute
             t._dynamo_weak_dynamic_indices = set()
         # TODO(voz): Should we bounds check?
+        # pyrefly: ignore # missing-attribute
         t._dynamo_weak_dynamic_indices.add(index)
         return

@@ -745,8 +756,11 @@ def mark_static(
         # TODO: Make this configurable via a supported public API
         _apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)

+    # pyrefly: ignore # bad-argument-type
     if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module):
+        # pyrefly: ignore # missing-attribute
         t._dynamo_marked_static = True
+        # pyrefly: ignore # bad-return
         return t

     if not isinstance(t, torch.Tensor):
diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py
index e463023caa77..c6eb87c42cb5 100644
--- a/torch/_dynamo/device_interface.py
+++ b/torch/_dynamo/device_interface.py
@@ -205,6 +205,7 @@ class CudaInterface(DeviceInterface):
     Event = torch.cuda.Event  # type: ignore[assignment]
     Stream = torch.cuda.Stream  # type: ignore[assignment]

+    # pyrefly: ignore # bad-override
     class Worker:
         @staticmethod
         def set_device(device: int) -> None:
@@ -240,6 +241,7 @@ class CudaInterface(DeviceInterface):
     set_device = staticmethod(torch.cuda.set_device)
     device_count = staticmethod(torch.cuda.device_count)
     stream = staticmethod(torch.cuda.stream)  # type: ignore[assignment]
+    # pyrefly: ignore # bad-override
     current_stream = staticmethod(torch.cuda.current_stream)
     set_stream = staticmethod(torch.cuda.set_stream)  # type: ignore[assignment]
     _set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id)  # type: ignore[assignment]
@@ -300,6 +302,7 @@ class MtiaInterface(DeviceInterface):
     Event = torch.mtia.Event  # type: ignore[assignment]
     Stream = torch.mtia.Stream  # type: ignore[assignment]

+    # pyrefly: ignore # bad-override
     class Worker:
         @staticmethod
         def set_device(device: int) -> None:
@@ -335,6 +338,7 @@ class MtiaInterface(DeviceInterface):
     set_device = staticmethod(torch.mtia.set_device)  # type: ignore[assignment]
     device_count = staticmethod(torch.mtia.device_count)
     stream = staticmethod(torch.mtia.stream)  # type: ignore[assignment]
+    # pyrefly: ignore # bad-override
     current_stream = staticmethod(torch.mtia.current_stream)
     set_stream = staticmethod(torch.mtia.set_stream)  # type: ignore[assignment]
     _set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id)  # type: ignore[assignment]
@@ -381,6 +385,7 @@ class XpuInterface(DeviceInterface):
     Event = torch.xpu.Event  # type: ignore[assignment]
     Stream = torch.xpu.Stream  # type: ignore[assignment]

+    # pyrefly: ignore # bad-override
     class Worker:
         @staticmethod
         def set_device(device: int) -> None:
@@ -416,6 +421,7 @@ class XpuInterface(DeviceInterface):
     set_device = staticmethod(torch.xpu.set_device)
     device_count = staticmethod(torch.xpu.device_count)  # type: ignore[has-type]
     stream = staticmethod(torch.xpu.stream)  # type: ignore[assignment]
+    # pyrefly: ignore # bad-override
     current_stream = staticmethod(torch.xpu.current_stream)
     set_stream = staticmethod(torch.xpu.set_stream)  # type: ignore[assignment]
     _set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id)  # type: ignore[assignment]
@@ -458,6 +464,7 @@ class CpuDeviceProperties:

 class CpuInterface(DeviceInterface):
+    # pyrefly: ignore # bad-override
     class Event(torch.Event):
         def __init__(self, enable_timing: bool = True) -> None:
             self.time = 0.0
@@ -468,6 +475,7 @@ class CpuInterface(DeviceInterface):
         def record(self, stream: Any = None) -> None:
             self.time = time.perf_counter()

+    # pyrefly: ignore # bad-override
     class Worker:
         @staticmethod
         def get_device_properties(
@@ -543,6 +551,7 @@ class MpsInterface(DeviceInterface):
     def synchronize(device: torch.types.Device = None) -> None:
         torch.mps.synchronize()

+    # pyrefly: ignore # bad-override
     class Worker:
         @staticmethod
         def get_device_properties(device: torch.types.Device = None) -> Any:
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 8b55ac48cca2..c4fa1e4d1545 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -484,6 +484,7 @@ class OptimizedModule(torch.nn.Module):
         self._initialize()

     @property
+    # pyrefly: ignore # bad-override
     def training(self) -> bool:
         return self._orig_mod.training

@@ -892,6 +893,7 @@ class _TorchDynamoContext:
                 while cur_exn.__cause__ is not None:
                     cur_exn.__cause__.with_traceback(None)
                     cur_exn = cur_exn.__cause__
+                # pyrefly: ignore # invalid-inheritance
                 raise e.with_traceback(None) from e.__cause__  # User compiler error
             except ShortenTraceback as e:
                 # Failures in the backend likely don't have useful
@@ -1020,7 +1022,10 @@ class OptimizeContext(_TorchDynamoContext):
                 assert rebuild_ctx is not None
                 compiler_fn = rebuild_ctx()
                 ctx = torch._dynamo.compiled_autograd._enable(
-                    compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
+                    compiler_fn,
+                    # pyrefly: ignore # bad-argument-type
+                    dynamic=_dynamic,
+                    ignore_active_disable_ctx=False,
                 )
                 ctx.__enter__()
                 return functools.partial(ctx.__exit__, None, None, None)
@@ -1083,6 +1088,7 @@ class DisableContext(_TorchDynamoContext):
             cls_obj.__call__ = self(cls_obj.__call__)
             if issubclass(cls_obj, torch.nn.Module):
                 # NN module variable tracker directly inlines the _call_impl. Disable it.
+                # pyrefly: ignore # missing-attribute
                 cls_obj._call_impl = self(cls_obj._call_impl)
             return cls_obj

@@ -1988,6 +1994,7 @@ def export(
             path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any]
         ) -> Any:
             if isinstance(t, torch.Tensor):
+                # pyrefly: ignore # missing-attribute
                 return ambient_fake_mode.from_tensor(t, static_shapes=True)
             elif isinstance(t, _IntWrapper):
                 if (
@@ -2068,8 +2075,11 @@ def export(
                 )
                 and not trace_rules.check(call_to_inspect)
             ):
+                # pyrefly: ignore # unbound-name
                 dim_constraints.solve()
+                # pyrefly: ignore # unbound-name
                 forced_specializations = dim_constraints.forced_specializations()
+                # pyrefly: ignore # unbound-name
                 msg = dim_constraints.prettify_results(
                     original_signature,
                     dynamic_shapes,
@@ -2090,9 +2100,11 @@ def export(
             )

         # Error if we have any constraints on static values
+        # pyrefly: ignore # unbound-name
         for k in shape_env.var_to_range.keys():
             if isinstance(k, sympy.Integer):
                 constraint_violation_error = ConstraintViolationError(
+                    # pyrefly: ignore # unbound-name
                     f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
                     "It appears that you're trying to set a constraint on a "
                     f"value which we evaluated to have a static value of {k}. "
diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py
index e69b768ba374..f45b3647df45 100644
--- a/torch/_dynamo/exc.py
+++ b/torch/_dynamo/exc.py
@@ -369,6 +369,7 @@ def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedExc
         observed_exception_map[exc_type] = type(  # type: ignore[assignment]
             f"Observed{name}Error", (ObservedException,), {}
         )
+    # pyrefly: ignore # index-error
     return observed_exception_map[exc_type]
diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py
index 2ff3f6752f56..75d2020ce56f 100644
--- a/torch/_dynamo/external_utils.py
+++ b/torch/_dynamo/external_utils.py
@@ -96,7 +96,9 @@ def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
         args, kwargs = pytree.tree_map_only(
             torch.Tensor, lambda x: x.numpy(), (args, kwargs)
         )
+        # pyrefly: ignore # invalid-param-spec
         out = f(*args, **kwargs)
+        # pyrefly: ignore # missing-attribute
         return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)

     return wrap
diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py
index 43d0dbd544e2..c3c13973c4bb 100644
--- a/torch/_dynamo/functional_export.py
+++ b/torch/_dynamo/functional_export.py
@@ -250,6 +250,7 @@ class DynamoGraphTransformer(torch.fx.Transformer):
             else:
                 placeholder.node.meta["val"] = self.flat_inputs[i]

+            # pyrefly: ignore # unsupported-operation
             self.new_input_nodes[i] = placeholder

     def _create_placeholder_mapping(self) -> None:
@@ -324,12 +325,18 @@ class DynamoGraphTransformer(torch.fx.Transformer):

         # Copy module metadata like the original implementation
         if hasattr(self.module, "meta"):
+            # pyrefly: ignore # unsupported-operation
             if "dynamo_flat_name_to_original_fqn" in self.module.meta:
+                # pyrefly: ignore # index-error
                 result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
+                    # pyrefly: ignore # index-error
                     "dynamo_flat_name_to_original_fqn"
                 ]
+            # pyrefly: ignore # unsupported-operation
             if "dynamo_compile_id" in self.module.meta:
+                # pyrefly: ignore # index-error
                 result_gm.meta["dynamo_compile_id"] = self.module.meta[
+                    # pyrefly: ignore # index-error
                     "dynamo_compile_id"
                 ]

@@ -361,8 +368,11 @@ def _suggest_or_raise_constraint_violation(
             torch._ops.OpOverloadPacket | torch._ops.OpOverload,
         )
     ):
+        # pyrefly: ignore # unbound-name
        dim_constraints.solve()
+        # pyrefly: ignore # unbound-name
        forced_specializations = dim_constraints.forced_specializations()
+        # pyrefly: ignore # unbound-name
        msg = dim_constraints.prettify_results(
            inspect.signature(orig_callable),  # type: ignore[attr-defined]
            dynamic_shapes,
@@ -383,9 +393,11 @@ def _suggest_or_raise_constraint_violation(
        )

    # Error if we have any constraints on static values
+    # pyrefly: ignore # unbound-name
    for k in shape_env.var_to_range.keys():
        if isinstance(k, sympy.Integer):
            constraint_violation_error = ConstraintViolationError(
+                # pyrefly: ignore # unbound-name
                f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
                "It appears that you're trying to set a constraint on a "
                f"value which we evaluated to have a static value of {k}. "
diff --git a/torch/_dynamo/graph_region_tracker.py b/torch/_dynamo/graph_region_tracker.py
index c1463d290bc9..c16ce22a1ded 100644
--- a/torch/_dynamo/graph_region_tracker.py
+++ b/torch/_dynamo/graph_region_tracker.py
@@ -320,6 +320,7 @@ class GraphRegionTracker:
             if len(group) > 1:
                 region_group = []
                 min_rank = math.inf
+                # pyrefly: ignore # bad-assignment
                 for node in group:
                     # some nodes aren't in the topo ranking?
                     if node in topological_ranking:
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index e60d1ceb72e1..eed59d14ef2e 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -640,6 +640,7 @@ class GuardManagerWrapper:
             if isinstance(guard, RelationalGuard):
                 if guard not in self.printed_relational_guards:
                     self.printed_relational_guards.add(guard)
+                    # pyrefly: ignore # bad-argument-type
                     body.writelines(self.get_guard_lines(guard))
                 else:
                     body.writelines(
@@ -700,6 +701,7 @@ class GuardManagerWrapper:
             for guard in mgr.get_leaf_guards():
                 if isinstance(guard, RelationalGuard):
                     if guard not in relational_guards_seen:
+                        # pyrefly: ignore # bad-argument-type
                         self.code_parts.extend(get_code_parts(guard))
                         relational_guards_seen.add(guard)
                 else:
@@ -716,6 +718,7 @@ def from_numpy(a: Any) -> torch.Tensor:
     # Re-enable torch function since we disable it on leaf guards
     # we need it to properly construct the tensor if a default device is set
     with torch.overrides._enable_torch_function():
+        # pyrefly: ignore # missing-attribute
         return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a

@@ -729,6 +732,7 @@ def uninteresting_files() -> set[str]:

     from torch._dynamo.polyfills.loader import POLYFILLED_MODULES

+    # pyrefly: ignore # bad-argument-type
     mods.extend(POLYFILLED_MODULES)

     return {inspect.getfile(m) for m in mods}
@@ -2205,6 +2209,7 @@ class GuardBuilder(GuardBuilderBase):
             return

         # Python math library doesn't support complex nan, so we need to use numpy
+        # pyrefly: ignore # missing-attribute
         if istype(val, complex) and np.isnan(val):
             code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
             self._set_guard_export_info(guard, code)
@@ -2495,6 +2500,7 @@ class GuardBuilder(GuardBuilderBase):
             # sources for the corresponding tensor dimension.
             return [
                 TensorPropertySource(source, TensorProperty.SIZE, dim)
+                # pyrefly: ignore # missing-attribute
                 for source in output_graph.tracked_fakes_id_to_source[t_id]
             ]

@@ -2531,6 +2537,7 @@ class GuardBuilder(GuardBuilderBase):
             equalities_inputs = None

         def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]:
+            # pyrefly: ignore # missing-attribute
             return output_graph.shape_env.produce_guards_verbose(
                 [a.fake for a in fs],  # type: ignore[misc]
                 [a.source for a in fs],
                 equalities_inputs=equalities_inputs,
                 source_ref=self.source_ref,
                 # Export keeps static.
+                # pyrefly: ignore # missing-attribute
                 ignore_static=(not output_graph.export),
                 langs=langs,
             )
@@ -2599,7 +2607,9 @@ class GuardBuilder(GuardBuilderBase):
         if not python_fallback:
             assert cpp_code_parts  # type: ignore[possibly-undefined]
             code_parts, source_to_symbol = (
+                # pyrefly: ignore # unbound-name
                 cpp_code_parts.exprs,
+                # pyrefly: ignore # unbound-name, missing-attribute
                 cpp_code_parts.source_to_symbol,
             )

@@ -2630,7 +2640,9 @@ class GuardBuilder(GuardBuilderBase):

             assert cpp_code_parts  # type: ignore[possibly-undefined]
             code_parts, source_to_symbol = (
+                # pyrefly: ignore # unbound-name
                 cpp_code_parts.exprs,
+                # pyrefly: ignore # unbound-name, missing-attribute
                 cpp_code_parts.source_to_symbol,
             )

@@ -3240,6 +3252,7 @@ class GuardsStatePickler(pickle.Pickler):
         assert _.__closure__ is not None
         return _.__closure__[0]

+    # pyrefly: ignore # bad-override
     def reducer_override(
         self, obj: Any
     ) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
@@ -4072,10 +4085,13 @@ class CheckFunctionManager:
             and (cache_entry := self.guard_manager.cache_entry) is not None
             and (extra_state := self.guard_manager.extra_state) is not None
         ):
+            # pyrefly: ignore # unbound-name
             assert isinstance(cache_entry, CacheEntry)
+            # pyrefly: ignore # unbound-name
             assert isinstance(extra_state, ExtraState)
             reason = f"Cache line invalidated because {obj_str} got deallocated"
             deleted_guard_manager = DeletedGuardManagerWrapper(reason)
+            # pyrefly: ignore # unbound-name
             extra_state.invalidate(cache_entry, deleted_guard_manager)
             self.guard_manager = deleted_guard_manager
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 9f0e40ffbf9f..9f98440a4dba 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -707,6 +707,7 @@ class OutputGraph(OutputGraphCommon):
         self.backward_state_proxy: Optional[torch.fx.Proxy] = None
         self.backward_state_var: Optional[str] = None

+        # pyrefly: ignore # bad-override
         self.name_of_builtins_dict_key_in_fglobals: str = (
             self.install_builtins_dict_in_fglobals()
         )
@@ -1146,6 +1147,7 @@ class OutputGraph(OutputGraphCommon):
             vt = self.root_tx.output.side_effects.track_object_existing(target, vt)

             assert "tensor_dict" not in vt.as_proxy().node.meta
+            # pyrefly: ignore # bad-argument-type
             vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target)

             return vt
@@ -1157,6 +1159,7 @@ class OutputGraph(OutputGraphCommon):
             install_guard(source.make_guard(GuardBuilder.NN_MODULE))

             def wrap_name(module_key: str) -> VariableTracker:
+                # pyrefly: ignore # bad-argument-type
                 return NNModuleVariable(type(target), module_key, target, **options)

         else:
@@ -1970,7 +1973,9 @@ class OutputGraph(OutputGraphCommon):
         tx = self.root_tx
         assert tx is not None
         if (ds := tx.distributed_state) is not None and ds.all_states is None:
+            # pyrefly: ignore # unbound-name
             compile_pg = ds.compile_pg
+            # pyrefly: ignore # unbound-name
             log.info("compiler_collective %s", ds.local_state)
             torch._logging.trace_structured(
                 "artifact",
                 metadata_fn=lambda: {
                     "name": "compiler_collective",
                     "encoding": "string",
                 },
+                # pyrefly: ignore # unbound-name
                 payload_fn=lambda: ds.local_state.render(),
             )
             device_types = compile_pg._device_types
@@ -1991,7 +1997,9 @@ class OutputGraph(OutputGraphCommon):
                 dynamo_timed("compiler_collective", log_pt2_compile_event=True),
             ):
                 all_states: list[Any] = [None] * compile_pg.size()
+                # pyrefly: ignore # unbound-name
                 dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
+                # pyrefly: ignore # unbound-name
                 ds.all_states = all_states
             # Clear speculation log, because are tracing may diverge due to
             # this information from the compiler collective
@@ -2321,6 +2329,7 @@ class OutputGraph(OutputGraphCommon):
             },
         )

+        # pyrefly: ignore # unbound-name
         return compiled_fn

     def dedup_pass(self) -> dict[str, torch.fx.GraphModule]:
@@ -2375,6 +2384,7 @@ class OutputGraph(OutputGraphCommon):
                 isinstance(b, torch.SymBool)
                 and (r := b.node.maybe_as_bool()) is not None
             ):
+                # pyrefly: ignore # unbound-name
                 return r
             # TODO: We can also technically remove all cases when the input
             # doesn't have unbacked inputs, since it's all in the ShapeEnv
@@ -2740,6 +2750,7 @@ def check_pt2_compliant_op(
             hints=[],
         )

+    # pyrefly: ignore # unbound-name
     op = getattr(target, overload)
     if torch.Tag.pt2_compliant_tag in op.tags:
         encountered_compliant_op(op)
@@ -2747,6 +2758,7 @@ def check_pt2_compliant_op(
         encountered_non_compliant_op(
             op,
             f"Encountered the torch.ops.OpOverloadPacket {target} "
+            # pyrefly: ignore # unbound-name
             f"which resolves to the overload ({overload}) that is "
             f"not PT2 compliant.",
         )
@@ -2767,6 +2779,7 @@ class LazyProxy:
         **kwargs: P.kwargs,
     ) -> None:
         self.tracer = tracer
+        # pyrefly: ignore # invalid-type-var
         self.fn = fn
         self.args = args
         self.kwargs = kwargs
diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py
index ffb6e550f978..6acc89fffac9 100644
--- a/torch/_dynamo/package.py
+++ b/torch/_dynamo/package.py
@@ -319,6 +319,7 @@ def _get_code_source(code: types.CodeType) -> tuple[str, str]:
     code_source = _find_code_source(toplevel)
     if code_source is None:
         _raise_resolution_error(code, toplevel)
+    # pyrefly: ignore # missing-attribute
     return toplevel.__qualname__, code_source.strip(".")

@@ -593,9 +594,11 @@ class CompilePackage:
                     f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
                 )

+        # pyrefly: ignore # bad-assignment
        self._source_info = dynamo.source_info

        main, *codes = dynamo.codes
+        # pyrefly: ignore # bad-assignment
        self._codes = {self._innermost_fn.__code__: main}
        for code in codes:
            self._codes[SerializedCode.to_code_object(code.python_code)] = code
@@ -603,6 +606,7 @@ class CompilePackage:
            self._add_function(
                self._innermost_fn.__code__, self._innermost_fn.__module__
            )
+        # pyrefly: ignore # bad-assignment
        self._initialized = True

    def _add_function(
@@ -746,6 +750,7 @@ class CompilePackage:
            for name in names:
                module.__dict__.pop(name)

+        # pyrefly: ignore # bad-assignment
        self._installed_globals = {}

        _reset_precompile_entries(self._innermost_fn.__code__)
diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py
index 905b7a11c6e8..ad068958dcb3 100644
--- a/torch/_dynamo/pgo.py
+++ b/torch/_dynamo/pgo.py
@@ -167,6 +167,7 @@ class CodeId:
 @dataclasses.dataclass
 class CodeState:
     automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
+        # pyrefly: ignore # unbound-name
         default_factory=lambda: defaultdict(FrameStateSizeEntry)
     )

@@ -851,6 +852,7 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
         not _CODE_STATE
         and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
     ):
+        # pyrefly: ignore # unbound-name
         extra_read_key = get_extra_cache_key(sticky_read)
         if extra_read_key is not None:
             get_extra_remote_code_state(extra_read_key)
diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py
index ff87743d5be5..954fbd994e75 100644
--- a/torch/_dynamo/polyfills/itertools.py
+++ b/torch/_dynamo/polyfills/itertools.py
@@ -196,6 +196,7 @@ def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:

 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T1],
     /,
@@ -205,6 +206,7 @@ def zip_longest(

 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T1],
     iter2: Iterable[_T2],
@@ -213,6 +215,7 @@ def zip_longest(

 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T1],
     iter2: Iterable[_T2],
@@ -223,6 +226,7 @@ def zip_longest(

 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T],
     iter2: Iterable[_T],
@@ -233,6 +237,7 @@ def zip_longest(

 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T],
     iter2: Iterable[_T],
diff --git a/torch/_dynamo/polyfills/operator.py b/torch/_dynamo/polyfills/operator.py
index 4ce889b297c9..4a24ce20bf21 100644
--- a/torch/_dynamo/polyfills/operator.py
+++ b/torch/_dynamo/polyfills/operator.py
@@ -30,10 +30,12 @@ _Us = TypeVarTuple("_Us")

 @overload
+# pyrefly: ignore # inconsistent-overload
 def attrgetter(attr: str, /) -> Callable[[Any], _U]: ...

 @overload
+# pyrefly: ignore # inconsistent-overload
 def attrgetter(
     attr1: str, attr2: str, /, *attrs: str
 ) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...
@@ -68,10 +70,12 @@ def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]:

 @overload
+# pyrefly: ignore # inconsistent-overload
 def itemgetter(item: _T, /) -> Callable[[Any], _U]: ...

 @overload
+# pyrefly: ignore # inconsistent-overload
 def itemgetter(
     item1: _T1, item2: _T2, /, *items: Unpack[_Ts]
 ) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...
diff --git a/torch/_dynamo/polyfills/os.py b/torch/_dynamo/polyfills/os.py
index 5388816b8267..98adc5582d0f 100644
--- a/torch/_dynamo/polyfills/os.py
+++ b/torch/_dynamo/polyfills/os.py
@@ -17,6 +17,7 @@ __all__ = ["fspath"]
 @substitute_in_graph(os.fspath, can_constant_fold_through=True)
 def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
     if isinstance(path, (str, bytes)):
+        # pyrefly: ignore # bad-return
         return path

     path_type = type(path)
diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py
index dfad40de4b08..9f2b7d9636d4 100644
--- a/torch/_dynamo/polyfills/pytree.py
+++ b/torch/_dynamo/polyfills/pytree.py
@@ -171,6 +171,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
                 or optree.is_namedtuple_class(treespec.type)
                 or optree.is_structseq_class(treespec.type)
             ):
+                # pyrefly: ignore # bad-return
                 return treespec._unflatten_func(
                     treespec._metadata,
                     children_representations,
diff --git a/torch/_dynamo/profiler.py b/torch/_dynamo/profiler.py
index 2055507f72a4..8c0c862f3404 100644
--- a/torch/_dynamo/profiler.py
+++ b/torch/_dynamo/profiler.py
@@ -49,8 +49,11 @@ class ProfileMetrics:
         if isinstance(other, int):
             other = ProfileMetrics(other, other, other)
         return ProfileMetrics(
+            # pyrefly: ignore # no-matching-overload
             self.microseconds / max(1, other.microseconds),
+            # pyrefly: ignore # bad-argument-type
             self.operators / max(1, other.operators),
+            # pyrefly: ignore # bad-argument-type
             self.fusions / max(1, other.fusions),
         )
diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py
index 998acc739775..c512ce891700 100644
--- a/torch/_dynamo/repro/after_aot.py
+++ b/torch/_dynamo/repro/after_aot.py
@@ -370,13 +370,16 @@ isolate_fails_code_str = None

             try:
                 if isinstance(kernel, Autotuner):
+                    # pyrefly: ignore # missing-attribute
                     if isinstance(kernel.fn, Heuristics):
                         model_str += "ERROR: Repro will not work as intended, "
                        model_str += "triton.runtime.autotuner.Heuristics is not currently supported\n"
                         break

                     config_strs = []
+                    # pyrefly: ignore # missing-attribute
                     for kernel_config in kernel.configs:
+                        # pyrefly: ignore # bad-argument-type
                         config_strs.append(f"""triton.Config(
                             {str(kernel_config.kwargs)},
                             num_warps={kernel_config.num_warps},
@@ -394,8 +397,10 @@ isolate_fails_code_str = None
                 """).strip()

                 model_str += "\n@triton.jit\n"
+                # pyrefly: ignore # missing-attribute
                 src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src
                 fn_name = (
+                    # pyrefly: ignore # missing-attribute
                     kernel._fn_name
                     if isinstance(kernel, JITFunction)
                     else kernel.fn._fn_name
@@ -409,7 +414,9 @@ isolate_fails_code_str = None
                 model_str += "ERROR: Repro will not work as intended, "
                 model_str += f"User defined triton kernel exception: {e}\n"

+        # pyrefly: ignore # unbound-name
         if len(kernel_side_table.constant_args) > 0:
+            # pyrefly: ignore # unbound-name
             model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n"

     model_str += NNModuleToString.convert(gm)
@@ -420,8 +427,10 @@ isolate_fails_code_str = None
     # Extract from graph placeholders and their corresponding arguments
     placeholder_targets = fx_placeholder_targets(gm)
     for placeholder, arg in zip(placeholder_targets, args):
+        # pyrefly: ignore # unbound-name
         if isinstance(arg, (int, torch.SymInt)):
             writer.symint(placeholder, arg)
+        # pyrefly: ignore # unbound-name
         elif isinstance(arg, torch.Tensor):
             # TODO: improve these names with FQN
             writer.tensor(placeholder, arg)
@@ -431,16 +440,20 @@ isolate_fails_code_str = None
             writer.unsupported(placeholder, arg)

         # Extract symbolic variables from the same arguments
+        # pyrefly: ignore # unbound-name
         if isinstance(arg, torch.SymInt):
             sym_name = str(arg.node)
             if arg.node.hint is not None:
                 used_syms[sym_name] = arg.node.hint
+        # pyrefly: ignore # unbound-name
         elif isinstance(arg, torch.Tensor):
             # Extract symbolic variables from tensor shapes and strides
             for dim in arg.shape:
+                # pyrefly: ignore # unbound-name
                 if isinstance(dim, torch.SymInt) and dim.node.hint is not None:
                     used_syms[str(dim.node)] = dim.node.hint
             for stride in arg.stride():
+                # pyrefly: ignore # unbound-name
                 if isinstance(stride, torch.SymInt) and stride.node.hint is not None:
                     used_syms[str(stride.node)] = stride.node.hint

@@ -758,6 +771,7 @@ def repro_common(
     # TODO: speed this up
     mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args)

+    # pyrefly: ignore # bad-assignment
     torch._inductor.config.generate_intermediate_hooks = True

     return mod, args
diff --git a/torch/_dynamo/repro/aoti.py b/torch/_dynamo/repro/aoti.py
index e0aaf4caee47..eae021752fd9 100644
--- a/torch/_dynamo/repro/aoti.py
+++ b/torch/_dynamo/repro/aoti.py
@@ -301,6 +301,7 @@ def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]:
 def repro_common(
     options: Any, exported_program: ExportedProgram
 ) -> tuple[torch.fx.GraphModule, Any, Any]:
+    # pyrefly: ignore # bad-assignment
     torch._inductor.config.generate_intermediate_hooks = True
     mod = exported_program.module(check_guards=False)
     args, kwargs = exported_program.example_inputs
@@ -422,6 +423,7 @@ def repro_minify(
     ) -> bool:
         # Need to export first so the in_spec and out_spec are populated
         tuple_inputs = tuple(flat_example_inputs)
+        # pyrefly: ignore # bad-assignment
         gm = export_for_aoti_minifier(
             gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error
         )
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
index 613996d28dc2..0021152fc704 100644
--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -102,6 +102,7 @@ def _bytecode_from_template_with_split(
 def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
     # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
     # on torch._dynamo.utils.
+    # pyrefly: ignore # unknown-name
     global __import_torch_dot__dynamo_dot_utils
     try:
         dummy
@@ -555,6 +556,7 @@ class ContinueExecutionCache:

         # remap original instructions' exception table entries
         if old_hook_target_remap:
+            # pyrefly: ignore # unbound-name
             assert is_py311_plus
             for inst in instructions:
                 if (
diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py
index 80b22e55227c..4e45dc7446d2 100644
--- a/torch/_dynamo/side_effects.py
+++ b/torch/_dynamo/side_effects.py
@@ -696,6 +696,7 @@ class SideEffects:
                     cg.add_cache(var)
                     var.source = LocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
                 elif var.source is None:
+                    # pyrefly: ignore # bad-assignment
                     var.source = LocalCellSource(var.local_name)
             elif isinstance(var, variables.TensorVariable):
                 # NOTE: for historical reasons we never assigned local sources
@@ -732,6 +733,7 @@ class SideEffects:
                 if isinstance(var, variables.UserDefinedObjectVariable):

                     def load_new_method() -> None:
+                        # pyrefly: ignore # missing-attribute
                         assert var.base_cls_vt is not None
                         cg(var.base_cls_vt)  # type: ignore[attr-defined]
                         cg.extend_output([cg.create_load_attr("__new__")])
@@ -978,7 +980,9 @@ class SideEffects:

             elif self.is_attribute_mutation(var):
                 if isinstance(
-                    var, variables.UserDefinedDictVariable
+                    var,
+                    variables.UserDefinedDictVariable,
+                    # pyrefly: ignore # bad-argument-type
                 ) and self.is_modified(var._dict_vt):
                     # Do dict related update manually here. The store_attr
                     # mutations will be applied later.
@@ -1011,6 +1015,7 @@ class SideEffects:
                         ]
                     )

+                    # pyrefly: ignore # bad-argument-type
                     cg(var._dict_vt, allow_cache=False)  # Don't codegen via source
                     cg.extend_output(
                         [
@@ -1031,7 +1036,9 @@ class SideEffects:
                         ]
                     )
                 elif isinstance(
-                    var, variables.UserDefinedListVariable
+                    var,
+                    variables.UserDefinedListVariable,
+                    # pyrefly: ignore # bad-argument-type
                 ) and self.is_modified(var._list_vt):
                     # Update the list to the updated items. Be careful in
                     # calling the list methods and not the overridden methods.
@@ -1048,6 +1055,7 @@ class SideEffects:
                         ]
                     )

+                    # pyrefly: ignore # bad-argument-type
                     cg(var._list_vt, allow_cache=False)  # Don't codegen via source
                     cg.extend_output(
                         [
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 3e2003062b03..3e8f8053fe6a 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -563,6 +563,7 @@ def log_graph_break(
         )
     else:
         user_stack = get_stack_above_dynamo() + user_stack  # type: ignore[assignment]
+        # pyrefly: ignore # bad-argument-type
         user_stack = collapse_resume_frames(user_stack)
     user_stack_formatted = "".join(traceback.format_list(user_stack))
     user_stack_trace = (
@@ -1040,6 +1041,7 @@ class BytecodeDispatchTableMeta(type):
             op: getattr(cls, opname, functools.partial(_missing, opname))
             for opname, op in dis.opmap.items()
         }
+        # pyrefly: ignore # missing-attribute
         cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]

@@ -1788,13 +1790,17 @@ class InstructionTranslatorBase(
             source = self.import_source(module_name)

         if self.exec_recorder:
+            # pyrefly: ignore # unbound-name
             self.exec_recorder.add_local_mod(recorded_name, value)

+        # pyrefly: ignore # unbound-name
         if istype(value, (types.ModuleType, DummyModule)):
+            # pyrefly: ignore # unbound-name
             self.push(PythonModuleVariable(value, source=source))
         else:
             unimplemented_v2(
                 gb_type="Bad import result",
+                # pyrefly: ignore # unbound-name
                 context=typestr(value),
                 explanation="Import result is not a Python module.",
                 hints=[],
@@ -1873,6 +1879,7 @@ class InstructionTranslatorBase(
         exit, exc = self.popn(2)
         assert exc is None
         self.push(exc)
+        # pyrefly: ignore # bad-argument-type
         self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {}))

     def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None:
@@ -2294,7 +2301,9 @@ class InstructionTranslatorBase(
         ):
             return True
         elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
-            exc_instance.fn, expected_type.fn
+            exc_instance.fn,
+            # pyrefly: ignore # missing-attribute
+            expected_type.fn,
         ):
             return True

@@ -2354,26 +2363,37 @@ class InstructionTranslatorBase(
             assert isinstance(null, NullVariable)

         if not isinstance(
-            argsvars, BaseListVariable
+            # pyrefly: ignore # unbound-name
+            argsvars,
+            BaseListVariable,
+            # pyrefly: ignore # unbound-name
         ) and argsvars.has_force_unpack_var_sequence(self):
+            # pyrefly: ignore # unbound-name
             argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))

         # Unpack for cases like fn(**obj) where obj is a map
+        # pyrefly: ignore # unbound-name
         if isinstance(kwargsvars, UserDefinedObjectVariable):
             kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars)  # type: ignore[arg-type]

+        # pyrefly: ignore # unbound-name
         if not isinstance(argsvars, BaseListVariable) or not isinstance(
-            kwargsvars, ConstDictVariable
+            # pyrefly: ignore # unbound-name
+            kwargsvars,
+            ConstDictVariable,
         ):
             unimplemented_v2(
                 gb_type="Variadic function call with bad args/kwargs type",
+                # pyrefly: ignore # unbound-name
                 context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
                 explanation="Expected args to be a list and kwargs to be a dict",
                 hints=[*graph_break_hints.USER_ERROR],
             )

         # Map to a dictionary of str -> VariableTracker
+        # pyrefly: ignore # unbound-name, missing-attribute
         kwargsvars = kwargsvars.keys_as_python_constant()
+        # pyrefly: ignore # unbound-name, missing-attribute
         self.call_function(fn, argsvars.items, kwargsvars)

     @break_graph_if_unsupported(push=1)
@@ -2437,6 +2457,7 @@ class InstructionTranslatorBase(
     def LOAD_ATTR(self, inst: Instruction) -> None:
         if sys.version_info >= (3, 12):
+            # pyrefly: ignore # unsupported-operation
             if inst.arg % 2:
                 self.LOAD_METHOD(inst)
                 return
@@ -3029,14 +3050,17 @@ class InstructionTranslatorBase(
                 "(i.e. `a, b, c = d`).",
                 hints=[*graph_break_hints.USER_ERROR],
             )
+        # pyrefly: ignore # unbound-name
         if len(val) != inst.argval:
             unimplemented_v2(
                 gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
+                # pyrefly: ignore # unbound-name
                 context=f"expected length: {inst.argval}, actual: {len(val)}",
                 explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
                 "(i.e. `a, b, c = d`) with unexpected length.",
                 hints=[*graph_break_hints.DYNAMO_BUG],
             )
+        # pyrefly: ignore # unbound-name
         for i in reversed(val):
             self.push(i)

@@ -3409,9 +3433,13 @@ class InstructionTranslatorBase(
             args = [contents[1]]

         if kw_names:
+            # pyrefly: ignore # bad-argument-type
             args = args + contents[2 : -len(kw_names)]
+            # pyrefly: ignore # bad-argument-type
             kwargs_list = contents[-len(kw_names) :]
+            # pyrefly: ignore # no-matching-overload
             kwargs = dict(zip(kw_names, kwargs_list))
+            # pyrefly: ignore # bad-argument-type
             assert len(kwargs) == len(kw_names)
         else:
             args = args + contents[2:]
@@ -4118,6 +4146,7 @@ class InstructionTranslator(InstructionTranslatorBase):
             and isinstance(tos, LocalGeneratorObjectVariable)
         ):
             self.stack[-1] = ListIteratorVariable(
+                # pyrefly: ignore # unbound-name
                 tos.force_unpack_var_sequence(self),
                 mutation_type=ValueMutationNew(),
             )
@@ -4188,6 +4217,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
     """Trace and inline a called method"""

     symbolic_result: Optional[VariableTracker]
+    # pyrefly: ignore # bad-override
     parent: InstructionTranslatorBase

     @classmethod
@@ -4231,6 +4261,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         # trace through.
        if (
             hasattr(getattr(func, "fn", None), "_origin")
+            # pyrefly: ignore # missing-attribute
             and func.fn._origin is produce_trampoline_autograd_apply
         ):
             # Known sound
@@ -4305,12 +4336,14 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                 tracing_ctx.previously_inlined_functions[code] = result

         try:
+            # pyrefly: ignore # missing-attribute
             sub_locals = func.bind_args(parent, args, kwargs)
         except TypeError as e:
             # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
             raise ArgsMismatchError(  # noqa: B904
                 "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
                     reason=str(e),
+                    # pyrefly: ignore # missing-attribute
                     func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
                     args=[arg.python_type() for arg in args],
                     kwargs=kwargs,
@@ -4394,6 +4427,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
             sub_locals,
             parent.symbolic_globals,
             parent.symbolic_torch_function_state,
+            # pyrefly: ignore # bad-argument-type
             func,
         )
         return tracer
diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py
index 77860c720a6e..41ceb9ecbf41 100644
--- a/torch/_dynamo/test_case.py
+++ b/torch/_dynamo/test_case.py
@@ -153,7 +153,9 @@ class CPythonTestCase(TestCase):
     assertTupleEqual = unittest.TestCase.assertTupleEqual
     assertSetEqual = unittest.TestCase.assertSetEqual
     assertDictEqual = polyfills.assert_dict_equal
+    # pyrefly: ignore # bad-override
     assertRaises = unittest.TestCase.assertRaises
+    # pyrefly: ignore # bad-override
     assertRaisesRegex = unittest.TestCase.assertRaisesRegex
     assertWarns = unittest.TestCase.assertWarns
     assertWarnsRegex = unittest.TestCase.assertWarnsRegex
@@ -169,8 +171,10 @@ class CPythonTestCase(TestCase):
     ) -> Callable[..., Any]:
         # We want to compile only the test function, excluding any setup code
         # from unittest
+
         method = getattr(self, self._testMethodName)
         method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
+
         setattr(self, self._testMethodName, method)
         return fn
diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py
index f48dae1d0e33..07c0c172342e 100644
--- a/torch/_dynamo/test_minifier_common.py
+++ b/torch/_dynamo/test_minifier_common.py
@@ -207,6 +207,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
         launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py"))
         with open(launch_file) as f:
             launch_code = f.read()
+
         self.assertTrue(os.path.exists(launch_file))

         args = ["python3", launch_file, "minify", *minifier_args]
@@ -218,6 +219,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
         print("minifier stdout:", launch_proc.stdout.decode("utf-8"))
         stderr = launch_proc.stderr.decode("utf-8")
         print("minifier stderr:", stderr)
+
         self.assertNotIn("Input graph did not fail the tester", stderr)

         return launch_proc, launch_code
@@ -230,6 +232,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
         repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py"))
         with open(repro_file) as f:
             repro_code = f.read()
+
         self.assertTrue(os.path.exists(repro_file))

         repro_proc = self._maybe_subprocess_run(
@@ -296,11 +299,14 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
         if expected_error is None:
             # Just check that there was no error
             self.assertEqual(test_proc.returncode, 0)
+
             self.assertIsNone(repro_dir)
             return None

         # NB: Intentionally do not test return code; we only care about
@@ -1463,6 +1466,7 @@ class CompilationMetrics: compile_id = all_metrics.get("compile_id") all_metrics["compile_id"] = str(compile_id) if compile_id else None + # pyrefly: ignore # bad-argument-type return cls(**all_metrics) @@ -2253,6 +2257,7 @@ def is_jit_model( Union[ torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule, + # pyrefly: ignore # invalid-param-spec torch.jit.ScriptFunction[Any, Any], torch.jit.ScriptModule, ] @@ -2361,6 +2366,7 @@ def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]: cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) saved_state = [ (param, param._version, torch.clone(param)) + # pyrefly: ignore # bad-argument-type for param in itertools.chain(gm.parameters(), gm.buffers()) ] @@ -2626,13 +2632,16 @@ def get_items_from_dict(obj: dict[K, V]) -> Iterable[tuple[K, Union[V, Any]]]: if istype(obj, (dict, OrderedDict)): return obj.items() elif isinstance(obj, OrderedDict): + # pyrefly: ignore # bad-argument-type return [(k, OrderedDict.__getitem__(obj, k)) for k in OrderedDict.keys(obj)] else: + # pyrefly: ignore # bad-argument-type return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)] def nn_module_new(cls: Any) -> Any: obj = object_new(cls) + # pyrefly: ignore # bad-argument-type torch.nn.Module.__init__(obj) return obj @@ -2679,6 +2688,7 @@ def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any: dict_class = dict if isinstance(d, OrderedDict): dict_class = OrderedDict + # pyrefly: ignore # bad-argument-type return next(itertools.islice(dict_class.keys(d), n, n + 1)) @@ -3222,8 +3232,10 @@ def format_func_info(code: CodeType) -> str: @contextlib.contextmanager def disable_cache_limit() -> Generator[None, None, None]: prior = config.recompile_limit + # pyrefly: ignore # bad-assignment config.recompile_limit = sys.maxsize prior_acc_limit = config.accumulated_recompile_limit + # pyrefly: ignore # bad-assignment config.accumulated_recompile_limit = sys.maxsize try: @@ -3958,6 +3970,7 @@ class numpy_operator_wrapper(Generic[_P, R]): def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any: assert not kwargs + # pyrefly: ignore # bad-assignment args = ( tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args ) @@ -4157,6 +4170,7 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: # (x) + (y) # ~~^~~~~~~ while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#": + # pyrefly: ignore # unbound-name if ch in "\\#": cur_lineno, cur_col = nextline(cur_lineno, cur_col) else: @@ -4507,6 +4521,7 @@ class GmWrapper(torch.nn.Module): self.unflatten_fn = unflatten_fn def forward(self, *args: Any) -> Any: + # pyrefly: ignore # annotation-mismatch args: list[Any] = list(args) return self.gm(*self.unflatten_fn(args)) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 282b448d55ab..048d505636b8 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1028,6 +1028,7 @@ class BuiltinVariable(VariableTracker): def call_self_handler(tx: "InstructionTranslator", args, kwargs): try: + # pyrefly: ignore # not-callable result = self_handler(tx, *args, **kwargs) if result is not None: return result @@ -1035,6 +1036,7 @@ class BuiltinVariable(VariableTracker): # Check if binding is bad. inspect signature bind is expensive. # So check only when handler call fails. 
try: + # pyrefly: ignore # bad-argument-type inspect.signature(self_handler).bind(tx, *args, **kwargs) except TypeError as e: has_constant_handler = obj.has_constant_handler(args, kwargs) @@ -1087,6 +1089,7 @@ class BuiltinVariable(VariableTracker): hints=[*graph_break_hints.DYNAMO_BUG], from_exc=exc, ) + # pyrefly: ignore # unbound-name return VariableTracker.build(tx, res) else: @@ -1115,6 +1118,7 @@ class BuiltinVariable(VariableTracker): tx, args=list(map(ConstantVariable.create, exc.args)), ) + # pyrefly: ignore # unbound-name return VariableTracker.build(tx, res) handlers.append(constant_fold_handler) @@ -1437,6 +1441,7 @@ class BuiltinVariable(VariableTracker): resolved_fn = getattr(self.fn, name) if resolved_fn in dict_methods: if isinstance(args[0], variables.UserDefinedDictVariable): + # pyrefly: ignore # missing-attribute return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs) elif isinstance(args[0], variables.ConstDictVariable): return args[0].call_method(tx, name, args[1:], kwargs) @@ -1445,6 +1450,7 @@ class BuiltinVariable(VariableTracker): resolved_fn = getattr(self.fn, name) if resolved_fn in set_methods: if isinstance(args[0], variables.UserDefinedSetVariable): + # pyrefly: ignore # missing-attribute return args[0]._set_vt.call_method(tx, name, args[1:], kwargs) elif isinstance(args[0], variables.SetVariable): return args[0].call_method(tx, name, args[1:], kwargs) @@ -1533,10 +1539,12 @@ class BuiltinVariable(VariableTracker): if type(arg.value).__str__ is object.__str__: # Rely on the object str method try: + # pyrefly: ignore # unbound-name return variables.ConstantVariable.create(value=str_method()) except AttributeError: # Graph break return + # pyrefly: ignore # unbound-name elif is_wrapper_or_member_descriptor(str_method): unimplemented_v2( gb_type="Attempted to a str() method implemented in C/C++", @@ -1653,8 +1661,10 @@ class BuiltinVariable(VariableTracker): else: raw_b = b.raw_value if self.fn is max: + # pyrefly: ignore # missing-attribute raw_res = max(a.raw_value, raw_b) else: + # pyrefly: ignore # missing-attribute raw_res = min(a.raw_value, raw_b) need_unwrap = any( @@ -2106,6 +2116,7 @@ class BuiltinVariable(VariableTracker): ) if isinstance(arg, variables.UserDefinedExceptionClassVariable): + # pyrefly: ignore # unbound-name return ConstantVariable.create(isinstance(arg_type, isinstance_type)) isinstance_type_tuple: tuple[type, ...] @@ -2138,8 +2149,10 @@ class BuiltinVariable(VariableTracker): # through it. This is a limitation of the current implementation. # Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it # might not be a big issue and we trade off it for performance. + # pyrefly: ignore # unbound-name val = issubclass(arg_type, isinstance_type_tuple) except TypeError: + # pyrefly: ignore # unbound-name val = arg_type in isinstance_type_tuple return variables.ConstantVariable.create(val) @@ -2161,6 +2174,7 @@ class BuiltinVariable(VariableTracker): # WARNING: This might run arbitrary user code `__subclasscheck__`. # See the comment in call_isinstance above. 
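The `unbound-name` suppressions above mark places where a name such as `res` or `str_method` is assigned on some paths (inside a `try:` or a conditional) and then used after a handler or branch that, judging from the surrounding code, cannot actually fall through, but that the checker cannot prove never does. A standalone sketch of that shape, with hypothetical helper names:

def _report_failure(msg: str) -> None:
    # In the real code the handler calls a helper that always raises; the
    # checker only sees the `-> None` annotation, so control appears able to
    # fall through with `value` unbound.
    raise RuntimeError(msg)


def parse_int(raw: str) -> int:
    try:
        value = int(raw)
    except ValueError:
        _report_failure(f"not an integer: {raw!r}")
    # pyrefly: ignore # unbound-name
    return value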
+ # pyrefly: ignore # unbound-name return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py)) def call_super(self, tx: "InstructionTranslator", a, b): @@ -2206,7 +2220,9 @@ class BuiltinVariable(VariableTracker): value = getattr(self.fn, name) except AttributeError: raise_observed_exception(AttributeError, tx) + # pyrefly: ignore # unbound-name if not callable(value): + # pyrefly: ignore # unbound-name return VariableTracker.build(tx, value, source) return variables.GetAttrVariable(self, name, source=source) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 2858e2af9252..d8800f0fa74f 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -651,6 +651,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): def handle_use_deterministic_algorithms( self, tx: "InstructionTranslator", mode, warn_only=False ): + # pyrefly: ignore # missing-attribute if warn_only and warn_only.as_python_constant(): unimplemented_v2( gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)", @@ -1035,6 +1036,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): else: raise torch._dynamo.exc.Unsupported("branch not supported") return variables.ConstantVariable.create( + # pyrefly: ignore # bad-argument-type torch.fx.experimental.symbolic_shapes.guard_scalar(val) ) @@ -1081,6 +1083,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): return return variables.ConstantVariable.create( + # pyrefly: ignore # bad-argument-type torch.fx.experimental.symbolic_shapes.has_static_value(val) ) @@ -1212,13 +1215,17 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): ) # need to guard only on no-arg get_device_module + # pyrefly: ignore # unbound-name if device is None: source = CallFunctionNoArgsSource(self.source) install_guard(source.make_guard(GuardBuilder.ID_MATCH)) # assumes `module` is in the form `torch.xyz` new_source = AttrSource( - TorchSource(), module.__name__.rsplit(".", maxsplit=1)[-1] + TorchSource(), + # pyrefly: ignore # unbound-name + module.__name__.rsplit(".", maxsplit=1)[-1], ) + # pyrefly: ignore # unbound-name return VariableTracker.build(tx, module, new_source) @register(torch.set_default_device) @@ -1373,9 +1380,12 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): f"{fn.__name__}_spec", f_spec ) input_spec_proxy = tx.output.register_static_attr_and_return_proxy( - fn.__name__ + "_input_spec", input_spec + fn.__name__ + "_input_spec", + # pyrefly: ignore # unbound-name + input_spec, ) f_spec_proxy.node.type = type(f_spec) + # pyrefly: ignore # unbound-name input_spec_proxy.node.type = type(input_spec) all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args) @@ -1716,6 +1726,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th ) # this results in cleaner graphs, but only works for inputs + # pyrefly: ignore # missing-attribute if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) @@ -1734,7 +1745,9 @@ For now, dynamo will explicitly graph break when it encounters user code with th # TODO[@lucaskabela]: Remove the behavior below since it is deprecated if isinstance( - data, TensorWithTFOverrideVariable + data, + TensorWithTFOverrideVariable, + # pyrefly: ignore # missing-attribute ) or is_traceable_wrapper_subclass_type(data.class_type): unimplemented_v2( gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass", @@ -1757,8 +1770,11 @@ For now, dynamo will explicitly graph break when it 
encounters user code with th ) try: + # pyrefly: ignore # missing-attribute shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) + # pyrefly: ignore # missing-attribute dtype = data.var_getattr(tx, "dtype").as_python_constant() + # pyrefly: ignore # missing-attribute device = data.var_getattr(tx, "device").as_python_constant() except NotImplementedError as e: unimplemented_v2( @@ -1773,9 +1789,13 @@ For now, dynamo will explicitly graph break when it encounters user code with th ) placeholder = tx.output.synthetic_graph_input( - new_parameter_placeholder, [shape, dtype, device, requires_grad] + new_parameter_placeholder, + # pyrefly: ignore # unbound-name + [shape, dtype, device, requires_grad], ) + # pyrefly: ignore # missing-attribute if data.requires_grad: + # pyrefly: ignore # missing-attribute data = data.call_method(tx, "detach", [], {}) from .builder import wrap_fx_proxy @@ -1785,6 +1805,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th tx.output.create_proxy( "call_function", tracable_create_parameter, + # pyrefly: ignore # missing-attribute (data.as_proxy(), placeholder.as_proxy()), {}, ), diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index b99d2667c3a9..416619cee029 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -646,7 +646,6 @@ def update_schema(): assert thrift_content[1].startswith("// checksum<<") thrift_checksum_real = _hash_content("\n".join(thrift_content[2:])) - # pyrefly: ignore # import-error from yaml import load, Loader dst = load(content, Loader=Loader) diff --git a/torch/_export/utils.py b/torch/_export/utils.py index a55385425373..dfe3c8a09da2 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -34,7 +34,7 @@ from torch.fx._pytree import ( _deregister_pytree_flatten_spec, register_pytree_flatten_spec, ) -from torch.utils._pytree import ( # pyrefly: ignore # deprecated +from torch.utils._pytree import ( _deregister_pytree_node, _register_pytree_node, Context, diff --git a/torch/_library/fake_profile.py b/torch/_library/fake_profile.py index 9a835dcd1dba..9e0b8cccdb56 100644 --- a/torch/_library/fake_profile.py +++ b/torch/_library/fake_profile.py @@ -198,7 +198,7 @@ def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str: to a file. The yaml string can be loaded back into an operator profile structure using `read_profiles_from_yaml`. """ - # pyrefly: ignore # import-error + import yaml from torch._export.serde.serialize import ( @@ -263,7 +263,7 @@ def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]: """ Reads the yaml saved by `save_op_profiles` and returns the operator profiles. 
""" - # pyrefly: ignore # import-error + import yaml from torch._export.serde.serialize import ( diff --git a/torch/fx/experimental/unification/multipledispatch/__init__.py b/torch/fx/experimental/unification/multipledispatch/__init__.py index b7d633ac1cee..bb7304069243 100644 --- a/torch/fx/experimental/unification/multipledispatch/__init__.py +++ b/torch/fx/experimental/unification/multipledispatch/__init__.py @@ -1,5 +1,5 @@ from .core import dispatch -from .dispatcher import ( # pyrefly: ignore # deprecated +from .dispatcher import ( Dispatcher, halt_ordering, MDNotImplementedError, diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 3d002b7b2365..fceec1272c16 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -153,6 +153,7 @@ class CausalBias(torch.Tensor): diagonal=diagonal_offset, ) + # pyrefly: ignore # bad-return def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor: """ Materializes the causal bias into a tensor form. diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index bcef3cc8adaa..dd6152ccdc6f 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -84,6 +84,7 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] +# pyrefly: ignore # invalid-inheritance class FlexKernelOptions(TypedDict, total=False): """Options for controlling the behavior of FlexAttention kernels. @@ -127,76 +128,93 @@ class FlexKernelOptions(TypedDict, total=False): """ # Performance tuning options + # pyrefly: ignore # invalid-annotation num_warps: NotRequired[int] """Number of warps to use in the CUDA kernel. Higher values may improve performance but increase register pressure. Default is determined by autotuning.""" + # pyrefly: ignore # invalid-annotation num_stages: NotRequired[int] """Number of pipeline stages in the CUDA kernel. Higher values may improve performance but increase shared memory usage. Default is determined by autotuning.""" + # pyrefly: ignore # invalid-annotation BLOCK_M: NotRequired[int] """Thread block size for the sequence length dimension of Q in forward pass. Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" + # pyrefly: ignore # invalid-annotation BLOCK_N: NotRequired[int] """Thread block size for the sequence length dimension of K/V in forward pass. Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" # Backward-specific block sizes (when prefixed with 'bwd_') + # pyrefly: ignore # invalid-annotation BLOCK_M1: NotRequired[int] """Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'. Default is determined by autotuning.""" + # pyrefly: ignore # invalid-annotation BLOCK_N1: NotRequired[int] """Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'. Default is determined by autotuning.""" + # pyrefly: ignore # invalid-annotation BLOCK_M2: NotRequired[int] """Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'. Default is determined by autotuning.""" + # pyrefly: ignore # invalid-annotation BLOCK_N2: NotRequired[int] """Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'. Default is determined by autotuning.""" + # pyrefly: ignore # invalid-annotation PRESCALE_QK: NotRequired[bool] """Whether to pre-scale QK by 1/sqrt(d) and change of base. 
This is slightly faster but may have more numerical error. Default: False.""" + # pyrefly: ignore # invalid-annotation ROWS_GUARANTEED_SAFE: NotRequired[bool] """If True, guarantees that at least one value in each row is not masked out. Allows skipping safety checks for better performance. Only set this if you are certain your mask guarantees this property. For example, causal attention is guaranteed safe because each query has at least 1 key-value to attend to. Default: False.""" + # pyrefly: ignore # invalid-annotation BLOCKS_ARE_CONTIGUOUS: NotRequired[bool] """If True, guarantees that all blocks in the mask are contiguous. Allows optimizing block traversal. For example, causal masks would satisfy this, but prefix_lm + sliding window would not. Default: False.""" + # pyrefly: ignore # invalid-annotation WRITE_DQ: NotRequired[bool] """Controls whether gradient scatters are done in the DQ iteration loop of the backward pass. Setting this to False will force this to happen in the DK loop which depending on your specific score_mod and mask_mod might be faster. Default: True.""" + # pyrefly: ignore # invalid-annotation FORCE_USE_FLEX_ATTENTION: NotRequired[bool] """If True, forces the use of the flex attention kernel instead of potentially using the more optimized flex-decoding kernel for short sequences. This can be a helpful option for debugging. Default: False.""" + # pyrefly: ignore # invalid-annotation USE_TMA: NotRequired[bool] """Whether to use Tensor Memory Accelerator (TMA) on supported hardware. This is experimental and may not work on all hardware, currently specific to NVIDIA GPUs Hopper+. Default: False.""" # ROCm-specific options + # pyrefly: ignore # invalid-annotation kpack: NotRequired[int] """ROCm-specific kernel packing parameter.""" + # pyrefly: ignore # invalid-annotation matrix_instr_nonkdim: NotRequired[int] """ROCm-specific matrix instruction non-K dimension.""" + # pyrefly: ignore # invalid-annotation waves_per_eu: NotRequired[int] """ROCm-specific waves per execution unit.""" @@ -581,6 +599,7 @@ class BlockMask: block_size = (self.BLOCK_SIZE,) # type: ignore[assignment] seq_lengths = (self.seq_lengths,) # type: ignore[assignment] + # pyrefly: ignore # not-iterable return ( *seq_lengths, self.kv_num_blocks, @@ -753,6 +772,7 @@ class BlockMask: partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices) if self.full_kv_num_blocks is not None: assert self.full_kv_indices is not None + # pyrefly: ignore # bad-return return partial_dense | _ordered_to_dense( self.full_kv_num_blocks, self.full_kv_indices ) diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py index 98a61bfb7c42..5d01f7f16a4a 100644 --- a/torch/nn/cpp.py +++ b/torch/nn/cpp.py @@ -78,6 +78,7 @@ class ModuleWrapper(nn.Module): # nn.Module defines training as a boolean @property # type: ignore[override] + # pyrefly: ignore # bad-override def training(self): return self.cpp_module.training diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 65c553df7e6e..57f63370aa65 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1266,6 +1266,7 @@ def adaptive_max_pool2d_with_indices( output_size, return_indices=return_indices, ) + # pyrefly: ignore # bad-argument-type output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool2d(input, output_size) @@ -1323,6 +1324,7 @@ def adaptive_max_pool3d_with_indices( output_size, return_indices=return_indices, ) + # pyrefly: ignore # bad-argument-type output_size = _list_with_default(output_size, 
input.size()) return torch._C._nn.adaptive_max_pool3d(input, output_size) @@ -1381,6 +1383,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> T """ if has_torch_function_unary(input): return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) + # pyrefly: ignore # bad-argument-type _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool2d(input, _output_size) @@ -1396,6 +1399,7 @@ def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> T """ if has_torch_function_unary(input): return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) + # pyrefly: ignore # bad-argument-type _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool3d(input, _output_size) @@ -2431,6 +2435,7 @@ def _no_grad_embedding_renorm_( input: Tensor, max_norm: float, norm_type: float, + # pyrefly: ignore # bad-return ) -> tuple[Tensor, Tensor]: torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type) @@ -2684,6 +2689,7 @@ def embedding_bag( if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested: include_last_offset = True + # pyrefly: ignore # missing-attribute offsets = input.offsets() input = input.values().reshape(-1) if per_sample_weights is not None: @@ -2818,6 +2824,7 @@ def batch_norm( eps=eps, ) if training: + # pyrefly: ignore # bad-argument-type _verify_batch_size(input.size()) return torch.batch_norm( @@ -2873,6 +2880,7 @@ def instance_norm( eps=eps, ) if use_input_stats: + # pyrefly: ignore # bad-argument-type _verify_spatial_size(input.size()) return torch.instance_norm( input, @@ -2998,11 +3006,13 @@ def local_response_norm( div = input.mul(input) if dim == 3: div = div.unsqueeze(1) + # pyrefly: ignore # bad-argument-type div = pad(div, (0, 0, size // 2, (size - 1) // 2)) div = avg_pool2d(div, (size, 1), stride=1).squeeze(1) else: sizes = input.size() div = div.view(sizes[0], 1, sizes[1], sizes[2], -1) + # pyrefly: ignore # bad-argument-type div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2)) div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1) div = div.view(sizes) @@ -3151,7 +3161,12 @@ def nll_loss( if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) return torch._C._nn.nll_loss_nd( - input, target, weight, _Reduction.get_enum(reduction), ignore_index + input, + target, + weight, + # pyrefly: ignore # bad-argument-type + _Reduction.get_enum(reduction), + ignore_index, ) @@ -3296,6 +3311,7 @@ def gaussian_nll_loss( var.clamp_(min=eps) # Calculate the loss + # pyrefly: ignore # unsupported-operation loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var) if full: loss += 0.5 * math.log(2 * math.pi) @@ -3471,6 +3487,7 @@ def cross_entropy( input, target, weight, + # pyrefly: ignore # bad-argument-type _Reduction.get_enum(reduction), ignore_index, label_smoothing, @@ -3535,6 +3552,7 @@ def binary_cross_entropy( new_size = _infer_size(target.size(), weight.size()) weight = weight.expand(new_size) + # pyrefly: ignore # bad-argument-type return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) @@ -3663,11 +3681,18 @@ def smooth_l1_loss( if beta == 0.0: return torch._C._nn.l1_loss( - expanded_input, expanded_target, _Reduction.get_enum(reduction) + expanded_input, + expanded_target, + # pyrefly: ignore # bad-argument-type + _Reduction.get_enum(reduction), ) else: return torch._C._nn.smooth_l1_loss( - 
expanded_input, expanded_target, _Reduction.get_enum(reduction), beta + expanded_input, + expanded_target, + # pyrefly: ignore # bad-argument-type + _Reduction.get_enum(reduction), + beta, ) @@ -3725,7 +3750,11 @@ def huber_loss( if weight is None: # Use the optimized C++ backend for standard Huber loss return torch._C._nn.huber_loss( - expanded_input, expanded_target, _Reduction.get_enum(reduction), delta + expanded_input, + expanded_target, + # pyrefly: ignore # bad-argument-type + _Reduction.get_enum(reduction), + delta, ) else: if weight.size() != input.size(): @@ -3733,7 +3762,11 @@ def huber_loss( # Calculate the unweighted loss first unweighted_loss = torch._C._nn.huber_loss( - expanded_input, expanded_target, _Reduction.get_enum("none"), delta + expanded_input, + expanded_target, + # pyrefly: ignore # bad-argument-type + _Reduction.get_enum("none"), + delta, ) # Apply weight to the unweighted loss @@ -3820,7 +3853,10 @@ def l1_loss( ) else: return torch._C._nn.l1_loss( - expanded_input, expanded_target, _Reduction.get_enum(reduction) + expanded_input, + expanded_target, + # pyrefly: ignore # bad-argument-type + _Reduction.get_enum(reduction), ) @@ -3895,7 +3931,10 @@ def mse_loss( ) else: return torch._C._nn.mse_loss( - expanded_input, expanded_target, _Reduction.get_enum(reduction) + expanded_input, + expanded_target, + # pyrefly: ignore # bad-argument-type + _Reduction.get_enum(reduction), ) @@ -4032,6 +4071,7 @@ def multilabel_margin_loss( reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) + # pyrefly: ignore # bad-argument-type return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) @@ -4073,6 +4113,7 @@ def soft_margin_loss( reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) + # pyrefly: ignore # bad-argument-type return torch._C._nn.soft_margin_loss(input, target, reduction_enum) @@ -4237,7 +4278,13 @@ def multi_margin_loss( raise ValueError("weight must be one-dimensional") return torch._C._nn.multi_margin_loss( - input, target, p, margin, weight, reduction_enum + input, + target, + p, + margin, + weight, + # pyrefly: ignore # bad-argument-type + reduction_enum, ) @@ -4383,6 +4430,7 @@ def upsample( # noqa: F811 scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None, + # pyrefly: ignore # bad-return ) -> Tensor: # noqa: B950 pass @@ -4394,6 +4442,7 @@ def upsample( # noqa: F811 scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None, + # pyrefly: ignore # bad-return ) -> Tensor: # noqa: B950 pass @@ -4496,6 +4545,7 @@ def interpolate( # noqa: F811 align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False, + # pyrefly: ignore # bad-return ) -> Tensor: # noqa: B950 pass @@ -4509,6 +4559,7 @@ def interpolate( # noqa: F811 align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False, + # pyrefly: ignore # bad-return ) -> Tensor: # noqa: B950 pass @@ -4522,6 +4573,7 @@ def interpolate( # noqa: F811 align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False, + # pyrefly: ignore # bad-return ) -> Tensor: # noqa: B950 pass @@ -4535,6 +4587,7 @@ def interpolate( # noqa: F811 align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False, + # pyrefly: 
ignore # bad-return ) -> Tensor: pass @@ -4709,6 +4762,7 @@ def interpolate( # noqa: F811 ( torch.floor( ( + # pyrefly: ignore # missing-attribute input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32) ).float() @@ -4733,21 +4787,28 @@ def interpolate( # noqa: F811 ) if input.dim() == 3 and mode == "nearest": + # pyrefly: ignore # bad-argument-type return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) if input.dim() == 4 and mode == "nearest": + # pyrefly: ignore # bad-argument-type return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) if input.dim() == 5 and mode == "nearest": + # pyrefly: ignore # bad-argument-type return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) if input.dim() == 3 and mode == "nearest-exact": + # pyrefly: ignore # bad-argument-type return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) if input.dim() == 4 and mode == "nearest-exact": + # pyrefly: ignore # bad-argument-type return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) if input.dim() == 5 and mode == "nearest-exact": + # pyrefly: ignore # bad-argument-type return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) if input.dim() == 3 and mode == "area": assert output_size is not None + # pyrefly: ignore # bad-argument-type return adaptive_avg_pool1d(input, output_size) if input.dim() == 4 and mode == "area": assert output_size is not None @@ -4759,13 +4820,21 @@ def interpolate( # noqa: F811 if input.dim() == 3 and mode == "linear": assert align_corners is not None return torch._C._nn.upsample_linear1d( - input, output_size, align_corners, scale_factors + input, + # pyrefly: ignore # bad-argument-type + output_size, + align_corners, + scale_factors, ) if input.dim() == 4 and mode == "bilinear": assert align_corners is not None if antialias: return torch._C._nn._upsample_bilinear2d_aa( - input, output_size, align_corners, scale_factors + input, + # pyrefly: ignore # bad-argument-type + output_size, + align_corners, + scale_factors, ) # Two levels are necessary to prevent TorchScript from touching # are_deterministic_algorithms_enabled. 
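Most `bad-argument-type` suppressions in the `interpolate` hunks have the same cause: by the time a `torch._C._nn.upsample_*` binding is reached, earlier validation guarantees that `output_size` / `scale_factors` have the shape the binding expects, but the names are still typed as `Optional[...]` unions, so the calls are reported anyway. A reduced sketch of that shape with hypothetical names (the validation happens behind a helper, so the checker cannot narrow the `Optional`):

from typing import Optional


def _require_size(output_size: Optional[list[int]]) -> bool:
    return output_size is not None and len(output_size) > 0


def _native_upsample(output_size: list[int]) -> list[int]:
    # Stand-in for a native binding that requires a concrete list[int].
    return output_size


def resize(output_size: Optional[list[int]]) -> list[int]:
    if not _require_size(output_size):
        raise ValueError("output_size must be a non-empty list")
    # The guard above ensures `output_size` is a non-empty list at runtime,
    # but it runs behind a helper call, so the static type here is still
    # `Optional[list[int]]` -- the situation the suppressions annotate.
    # pyrefly: ignore # bad-argument-type
    return _native_upsample(output_size)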
@@ -4778,7 +4847,11 @@ def interpolate( # noqa: F811 "torch._decomp.decompositions" )._upsample_linear_vec(input, output_size, align_corners, scale_factors) return torch._C._nn.upsample_bilinear2d( - input, output_size, align_corners, scale_factors + input, + # pyrefly: ignore # bad-argument-type + output_size, + align_corners, + scale_factors, ) if input.dim() == 5 and mode == "trilinear": assert align_corners is not None @@ -4793,16 +4866,28 @@ def interpolate( # noqa: F811 "torch._decomp.decompositions" )._upsample_linear_vec(input, output_size, align_corners, scale_factors) return torch._C._nn.upsample_trilinear3d( - input, output_size, align_corners, scale_factors + input, + # pyrefly: ignore # bad-argument-type + output_size, + align_corners, + scale_factors, ) if input.dim() == 4 and mode == "bicubic": assert align_corners is not None if antialias: return torch._C._nn._upsample_bicubic2d_aa( - input, output_size, align_corners, scale_factors + input, + # pyrefly: ignore # bad-argument-type + output_size, + align_corners, + scale_factors, ) return torch._C._nn.upsample_bicubic2d( - input, output_size, align_corners, scale_factors + input, + # pyrefly: ignore # bad-argument-type + output_size, + align_corners, + scale_factors, ) if input.dim() == 3 and mode == "bilinear": @@ -4834,6 +4919,7 @@ def upsample_nearest( # noqa: F811 input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, + # pyrefly: ignore # bad-return ) -> Tensor: pass @@ -4843,6 +4929,7 @@ def upsample_nearest( # noqa: F811 input: Tensor, size: Optional[list[int]] = None, scale_factor: Optional[float] = None, + # pyrefly: ignore # bad-return ) -> Tensor: pass @@ -4884,6 +4971,7 @@ def upsample_bilinear( # noqa: F811 input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, + # pyrefly: ignore # bad-return ) -> Tensor: pass @@ -4893,6 +4981,7 @@ def upsample_bilinear( # noqa: F811 input: Tensor, size: Optional[list[int]] = None, scale_factor: Optional[float] = None, + # pyrefly: ignore # bad-return ) -> Tensor: pass @@ -4902,6 +4991,7 @@ def upsample_bilinear( # noqa: F811 input: Tensor, size: Optional[int] = None, scale_factor: Optional[list[float]] = None, + # pyrefly: ignore # bad-return ) -> Tensor: pass @@ -4911,6 +5001,7 @@ def upsample_bilinear( # noqa: F811 input: Tensor, size: Optional[list[int]] = None, scale_factor: Optional[list[float]] = None, + # pyrefly: ignore # bad-return ) -> Tensor: pass @@ -5717,6 +5808,7 @@ def _in_projection_packed( .squeeze(-2) .contiguous() ) + # pyrefly: ignore # bad-return return proj[0], proj[1], proj[2] else: # encoder-decoder attention @@ -5735,6 +5827,7 @@ def _in_projection_packed( .squeeze(-2) .contiguous() ) + # pyrefly: ignore # bad-return return (q_proj, kv_proj[0], kv_proj[1]) else: w_q, w_k, w_v = w.chunk(3) @@ -5742,6 +5835,7 @@ def _in_projection_packed( b_q = b_k = b_v = None else: b_q, b_k, b_v = b.chunk(3) + # pyrefly: ignore # bad-return return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) @@ -6372,8 +6466,10 @@ def multi_head_attention_forward( k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) if attn_mask is not None: + # pyrefly: ignore # bad-argument-type attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: + # pyrefly: ignore # bad-argument-type key_padding_mask = pad(key_padding_mask, (0, 1)) else: assert bias_k is None @@ -6382,8 +6478,10 @@ def multi_head_attention_forward( # # reshape q, k, v for multihead attention and make them batch 
first # + # pyrefly: ignore # no-matching-overload q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) if static_k is None: + # pyrefly: ignore # no-matching-overload k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed @@ -6395,6 +6493,7 @@ def multi_head_attention_forward( ) k = static_k if static_v is None: + # pyrefly: ignore # no-matching-overload v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed @@ -6410,14 +6509,20 @@ def multi_head_attention_forward( if add_zero_attn: zero_attn_shape = (bsz * num_heads, 1, head_dim) k = torch.cat( - [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 + # pyrefly: ignore # no-matching-overload + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], + dim=1, ) v = torch.cat( - [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1 + # pyrefly: ignore # no-matching-overload + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], + dim=1, ) if attn_mask is not None: + # pyrefly: ignore # bad-argument-type attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: + # pyrefly: ignore # bad-argument-type key_padding_mask = pad(key_padding_mask, (0, 1)) # update source sequence length after adjustments @@ -6467,6 +6572,7 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) attn_output = ( + # pyrefly: ignore # no-matching-overload attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) ) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) @@ -6493,13 +6599,16 @@ def multi_head_attention_forward( attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) q = q.view(bsz, num_heads, tgt_len, head_dim) + # pyrefly: ignore # no-matching-overload k = k.view(bsz, num_heads, src_len, head_dim) + # pyrefly: ignore # no-matching-overload v = v.view(bsz, num_heads, src_len, head_dim) attn_output = scaled_dot_product_attention( q, k, v, attn_mask, dropout_p, is_causal ) attn_output = ( + # pyrefly: ignore # no-matching-overload attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) ) diff --git a/torch/nn/init.py b/torch/nn/init.py index 83183d8db5f4..e033198d4e5e 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -500,6 +500,7 @@ def xavier_normal_( def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int: + # pyrefly: ignore # bad-assignment mode = mode.lower() valid_modes = ["fan_in", "fan_out"] if mode not in valid_modes: diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index dd66c2b323c8..407fcc7e279f 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -6,6 +6,7 @@ from torch.autograd.function import Function class SyncBatchNorm(Function): @staticmethod + # pyrefly: ignore # bad-override def forward( self, input, @@ -210,6 +211,7 @@ class SyncBatchNorm(Function): class CrossMapLRN2d(Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): ctx.size = size ctx.alpha = alpha @@ -265,6 +267,7 @@ class CrossMapLRN2d(Function): return output @staticmethod + # pyrefly: ignore # bad-override def backward(ctx, grad_output): input, output = ctx.saved_tensors grad_input = grad_output.new() @@ -306,6 +309,7 @@ class CrossMapLRN2d(Function): class 
BackwardHookFunction(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, *args): ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad]) return args diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 12f20517c2ad..dbc32e0ff968 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -72,6 +72,7 @@ class _NormBase(Module): torch.tensor( 0, dtype=torch.long, + # pyrefly: ignore # bad-argument-type **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, ), ) @@ -221,6 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase): dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore # bad-argument-type super().__init__( # affine and track_running_stats are hardcoded to False to # avoid creating tensors that will soon be overwritten. @@ -234,22 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase): self.affine = affine self.track_running_stats = track_running_stats if self.affine: + # pyrefly: ignore # bad-argument-type self.weight = UninitializedParameter(**factory_kwargs) + # pyrefly: ignore # bad-argument-type self.bias = UninitializedParameter(**factory_kwargs) if self.track_running_stats: + # pyrefly: ignore # bad-argument-type self.running_mean = UninitializedBuffer(**factory_kwargs) + # pyrefly: ignore # bad-argument-type self.running_var = UninitializedBuffer(**factory_kwargs) self.num_batches_tracked = torch.tensor( 0, dtype=torch.long, + # pyrefly: ignore # bad-argument-type **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, ) def reset_parameters(self) -> None: + # pyrefly: ignore # bad-argument-type if not self.has_uninitialized_params() and self.num_features != 0: super().reset_parameters() def initialize_parameters(self, input) -> None: # type: ignore[override] + # pyrefly: ignore # bad-argument-type if self.has_uninitialized_params(): self.num_features = input.shape[1] if self.affine: diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index a03f57ea58a8..18e9619e4fcd 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -109,6 +109,7 @@ class Sequential(Module): def __init__(self, *args: Module) -> None: ... @overload + # pyrefly: ignore # inconsistent-overload def __init__(self, arg: OrderedDict[str, Module]) -> None: ... def __init__(self, *args): @@ -472,6 +473,7 @@ class ModuleList(Module): return self def pop(self, key: Union[int, slice]) -> Module: + # pyrefly: ignore # index-error v = self[key] del self[key] return v @@ -623,9 +625,11 @@ class ModuleDict(Module): "ModuleDict update sequence element " "#" + str(j) + " should be Iterable; is" + type(m).__name__ ) + # pyrefly: ignore # bad-argument-type if not len(m) == 2: raise ValueError( "ModuleDict update sequence element " + # pyrefly: ignore # bad-argument-type "#" + str(j) + " has length " + str(len(m)) + "; 2 is required" ) # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] @@ -684,6 +688,7 @@ class ParameterList(Module): def __getitem__(self, idx: int) -> Any: ... @overload + # pyrefly: ignore # inconsistent-overload def __getitem__(self: T, idx: slice) -> T: ... 
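The `bad-override` suppressions added to `SyncBatchNorm`, `CrossMapLRN2d`, and `BackwardHookFunction` above all sit on `@staticmethod` `forward`/`backward` definitions of `torch.autograd.Function` subclasses: the concrete signatures differ from the base class's generic `*args, **kwargs` signature, which is how custom autograd Functions are normally written. A minimal, self-contained Function (a hypothetical example, not taken from the patch) with the suppressions placed the same way:

import torch
from torch.autograd.function import Function


class Scale(Function):
    """Toy custom autograd Function, written in the same style as the
    Functions patched above."""

    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, input: torch.Tensor, factor: float) -> torch.Tensor:
        ctx.factor = factor
        return input * factor

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output: torch.Tensor):
        # d(input * factor)/d(input) = factor; `factor` itself gets no grad.
        return grad_output * ctx.factor, None


x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])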
def __getitem__(self, idx): @@ -769,9 +774,11 @@ class ParameterList(Module): size_str, device_str, ) + # pyrefly: ignore # bad-argument-type child_lines.append(" (" + str(k) + "): " + parastr) else: child_lines.append( + # pyrefly: ignore # bad-argument-type " (" + str(k) + "): Object of type: " + type(p).__name__ ) @@ -979,9 +986,11 @@ class ParameterDict(Module): "ParameterDict update sequence element " "#" + str(j) + " should be Iterable; is" + type(p).__name__ ) + # pyrefly: ignore # bad-argument-type if not len(p) == 2: raise ValueError( "ParameterDict update sequence element " + # pyrefly: ignore # bad-argument-type "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" ) # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment @@ -1002,9 +1011,11 @@ class ParameterDict(Module): size_str, device_str, ) + # pyrefly: ignore # bad-argument-type child_lines.append(" (" + str(k) + "): " + parastr) else: child_lines.append( + # pyrefly: ignore # bad-argument-type " (" + str(k) + "): Object of type: " + type(p).__name__ ) tmpstr = "\n".join(child_lines) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 2f15c3d488f7..1fc2d63eb4f3 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -363,6 +363,7 @@ class Conv1d(_ConvNd): self.dilation, self.groups, ) + # pyrefly: ignore # no-matching-overload return F.conv1d( input, weight, bias, self.stride, self.padding, self.dilation, self.groups ) @@ -540,6 +541,7 @@ class Conv2d(_ConvNd): self.dilation, self.groups, ) + # pyrefly: ignore # no-matching-overload return F.conv2d( input, weight, bias, self.stride, self.padding, self.dilation, self.groups ) @@ -709,6 +711,7 @@ class Conv3d(_ConvNd): self.dilation, self.groups, ) + # pyrefly: ignore # no-matching-overload return F.conv3d( input, weight, bias, self.stride, self.padding, self.dilation, self.groups ) @@ -1494,6 +1497,7 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore # bad-argument-type super().__init__( 0, 0, @@ -1508,9 +1512,11 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] padding_mode, **factory_kwargs, ) + # pyrefly: ignore # bad-override, bad-argument-type self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: + # pyrefly: ignore # bad-override, bad-argument-type self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1563,6 +1569,7 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore # bad-argument-type super().__init__( 0, 0, @@ -1577,9 +1584,11 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] padding_mode, **factory_kwargs, ) + # pyrefly: ignore # bad-override, bad-argument-type self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: + # pyrefly: ignore # bad-override, bad-argument-type self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1633,6 +1642,7 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore # bad-argument-type super().__init__( 0, 0, @@ -1647,9 +1657,11 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] padding_mode, **factory_kwargs, ) + # 
pyrefly: ignore # bad-override, bad-argument-type self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: + # pyrefly: ignore # bad-override, bad-argument-type self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1701,6 +1713,7 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore # bad-argument-type super().__init__( 0, 0, @@ -1716,9 +1729,11 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi padding_mode, **factory_kwargs, ) + # pyrefly: ignore # bad-override, bad-argument-type self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: + # pyrefly: ignore # bad-override, bad-argument-type self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1770,6 +1785,7 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore # bad-argument-type super().__init__( 0, 0, @@ -1785,9 +1801,11 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi padding_mode, **factory_kwargs, ) + # pyrefly: ignore # bad-override, bad-argument-type self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: + # pyrefly: ignore # bad-override, bad-argument-type self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1839,6 +1857,7 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} + # pyrefly: ignore # bad-argument-type super().__init__( 0, 0, @@ -1854,9 +1873,11 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi padding_mode, **factory_kwargs, ) + # pyrefly: ignore # bad-override, bad-argument-type self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: + # pyrefly: ignore # bad-override, bad-argument-type self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index 46e7c7be63db..1984eb0d0e15 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -172,7 +172,9 @@ class LazyModuleMixin: def __init__(self: _LazyProtocol, *args, **kwargs): # Mypy doesn't like this super call in a mixin super().__init__(*args, **kwargs) # type: ignore[misc] + # pyrefly: ignore # read-only self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) + # pyrefly: ignore # read-only self._initialize_hook = self.register_forward_pre_hook( self._infer_parameters, with_kwargs=True ) diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index 2a2d130590ef..0d17e3174615 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -286,6 +286,7 @@ class LazyLinear(LazyModuleMixin, Linear): """ cls_to_become = Linear # type: ignore[assignment] + # pyrefly: ignore # bad-override weight: UninitializedParameter bias: UninitializedParameter # type: ignore[assignment] @@ -295,16 +296,20 @@ class LazyLinear(LazyModuleMixin, Linear): factory_kwargs = {"device": device, "dtype": dtype} # bias is hardcoded to False to avoid creating tensor # that will soon be overwritten. 
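The lazy-module hunks repeat one pattern: `factory_kwargs = {"device": device, "dtype": dtype}` is built from `Optional` arguments, so the dict is inferred as `dict[str, ... | None]`, and each `**factory_kwargs` call site (`super().__init__`, `UninitializedParameter`, `torch.tensor`) is reported as `bad-argument-type` even though passing `None` for `device`/`dtype` is accepted at runtime. A reduced sketch with a hypothetical helper name:

from typing import Optional

import torch
from torch.nn.parameter import UninitializedParameter


def make_uninitialized_weight(
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> UninitializedParameter:
    # Inferred as `dict[str, torch.device | torch.dtype | None]`, so spreading
    # it into a constructor is the kind of call the hunks above suppress.
    factory_kwargs = {"device": device, "dtype": dtype}
    # pyrefly: ignore # bad-argument-type
    return UninitializedParameter(**factory_kwargs)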
+ # pyrefly: ignore # bad-argument-type super().__init__(0, 0, False) + # pyrefly: ignore # bad-argument-type self.weight = UninitializedParameter(**factory_kwargs) self.out_features = out_features if bias: + # pyrefly: ignore # bad-argument-type self.bias = UninitializedParameter(**factory_kwargs) def reset_parameters(self) -> None: """ Resets parameters based on their initialization used in ``__init__``. """ + # pyrefly: ignore # bad-argument-type if not self.has_uninitialized_params() and self.in_features != 0: super().reset_parameters() @@ -312,6 +317,7 @@ class LazyLinear(LazyModuleMixin, Linear): """ Infers ``in_features`` based on ``input`` and initializes parameters. """ + # pyrefly: ignore # bad-argument-type if self.has_uninitialized_params(): with torch.no_grad(): self.in_features = input.shape[-1] diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 531315646e9f..09484c89c0ee 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -38,11 +38,13 @@ T = TypeVar("T", bound="Module") class _IncompatibleKeys( + # pyrefly: ignore # invalid-inheritance namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), ): __slots__ = () def __repr__(self) -> str: + # pyrefly: ignore # missing-attribute if not self.missing_keys and not self.unexpected_keys: return "" return super().__repr__() @@ -91,6 +93,7 @@ class _WrappedHook: def __getstate__(self) -> dict: result = {"hook": self.hook, "with_module": self.with_module} if self.with_module: + # pyrefly: ignore # unsupported-operation result["module"] = self.module() return result @@ -976,7 +979,9 @@ class Module: # Decrement use count of the gradient by setting to None param.grad = None param_applied = torch.nn.Parameter( - param_applied, requires_grad=param.requires_grad + # pyrefly: ignore # bad-argument-type + param_applied, + requires_grad=param.requires_grad, ) torch.utils.swap_tensors(param, param_applied) except Exception as e: @@ -987,11 +992,13 @@ class Module: ) from e out_param = param elif p_should_use_set_data: + # pyrefly: ignore # bad-assignment param.data = param_applied out_param = param else: assert isinstance(param, Parameter) assert param.is_leaf + # pyrefly: ignore # bad-argument-type out_param = Parameter(param_applied, param.requires_grad) self._parameters[key] = out_param @@ -2253,6 +2260,7 @@ class Module: if destination is None: destination = OrderedDict() + # pyrefly: ignore # missing-attribute destination._metadata = OrderedDict() local_metadata = dict(version=self._version) @@ -2402,7 +2410,9 @@ class Module: if k not in self._non_persistent_buffers_set } local_name_params = itertools.chain( - self._parameters.items(), persistent_buffers.items() + self._parameters.items(), + # pyrefly: ignore # bad-argument-type + persistent_buffers.items(), ) local_state = {k: v for k, v in local_name_params if v is not None} assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 6c4c117d1a7d..2300a498acaa 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -84,6 +84,7 @@ class CircularPad1d(_CircularPadNd): [5., 6., 7., 4., 5., 6., 7., 4.]]]) """ + # pyrefly: ignore # bad-override padding: tuple[int, int] def __init__(self, padding: _size_2_t) -> None: @@ -144,6 +145,7 @@ class CircularPad2d(_CircularPadNd): [8., 6., 7., 8., 6.]]]]) """ + # pyrefly: ignore # bad-override padding: tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> 
None: @@ -194,6 +196,7 @@ class CircularPad3d(_CircularPadNd): >>> output = m(input) """ + # pyrefly: ignore # bad-override padding: tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t) -> None: @@ -265,6 +268,7 @@ class ConstantPad1d(_ConstantPadNd): [ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]]) """ + # pyrefly: ignore # bad-override padding: tuple[int, int] def __init__(self, padding: _size_2_t, value: float) -> None: @@ -316,6 +320,7 @@ class ConstantPad2d(_ConstantPadNd): """ __constants__ = ["padding", "value"] + # pyrefly: ignore # bad-override padding: tuple[int, int, int, int] def __init__(self, padding: _size_4_t, value: float) -> None: @@ -356,6 +361,7 @@ class ConstantPad3d(_ConstantPadNd): >>> output = m(input) """ + # pyrefly: ignore # bad-override padding: tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t, value: float) -> None: @@ -409,6 +415,7 @@ class ReflectionPad1d(_ReflectionPadNd): [7., 6., 5., 4., 5., 6., 7., 6.]]]) """ + # pyrefly: ignore # bad-override padding: tuple[int, int] def __init__(self, padding: _size_2_t) -> None: @@ -462,6 +469,7 @@ class ReflectionPad2d(_ReflectionPadNd): [7., 6., 7., 8., 7.]]]]) """ + # pyrefly: ignore # bad-override padding: tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> None: @@ -517,6 +525,7 @@ class ReflectionPad3d(_ReflectionPadNd): [1., 0., 1., 0.]]]]]) """ + # pyrefly: ignore # bad-override padding: tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t) -> None: @@ -570,6 +579,7 @@ class ReplicationPad1d(_ReplicationPadNd): [4., 4., 4., 4., 5., 6., 7., 7.]]]) """ + # pyrefly: ignore # bad-override padding: tuple[int, int] def __init__(self, padding: _size_2_t) -> None: @@ -623,6 +633,7 @@ class ReplicationPad2d(_ReplicationPadNd): [6., 6., 7., 8., 8.]]]]) """ + # pyrefly: ignore # bad-override padding: tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> None: @@ -665,6 +676,7 @@ class ReplicationPad3d(_ReplicationPadNd): >>> output = m(input) """ + # pyrefly: ignore # bad-override padding: tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t) -> None: diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index be48f9e7190f..84489a76aa98 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -111,6 +111,7 @@ class RNNBase(Module): if ( not isinstance(dropout, numbers.Number) + # pyrefly: ignore # unsupported-operation or not 0 <= dropout <= 1 or isinstance(dropout, bool) ): @@ -119,6 +120,7 @@ class RNNBase(Module): "representing the probability of an element being " "zeroed" ) + # pyrefly: ignore # unsupported-operation if dropout > 0 and num_layers == 1: warnings.warn( "dropout option adds dropout after all but last " @@ -639,15 +641,22 @@ class RNN(RNNBase): @overload @torch._jit_internal._overload_method # noqa: F811 + # pyrefly: ignore # bad-override def forward( - self, input: Tensor, hx: Optional[Tensor] = None + self, + input: Tensor, + hx: Optional[Tensor] = None, + # pyrefly: ignore # bad-return ) -> tuple[Tensor, Tensor]: pass @overload @torch._jit_internal._overload_method # noqa: F811 def forward( - self, input: PackedSequence, hx: Optional[Tensor] = None + self, + input: PackedSequence, + hx: Optional[Tensor] = None, + # pyrefly: ignore # bad-return ) -> tuple[PackedSequence, Tensor]: pass @@ -772,7 +781,11 @@ class RNN(RNNBase): if isinstance(orig_input, PackedSequence): output_packed = PackedSequence( - output, batch_sizes, sorted_indices, 
unsorted_indices + output, + # pyrefly: ignore # bad-argument-type + batch_sizes, + sorted_indices, + unsorted_indices, ) return output_packed, self.permute_hidden(hidden, unsorted_indices) @@ -996,6 +1009,7 @@ class LSTM(RNNBase): # In the future, we should prevent mypy from applying contravariance rules here. # See torch/nn/modules/module.py::_forward_unimplemented + # pyrefly: ignore # bad-override def check_forward_args( self, input: Tensor, @@ -1029,8 +1043,12 @@ class LSTM(RNNBase): # Same as above, see torch/nn/modules/module.py::_forward_unimplemented @overload # type: ignore[override] @torch._jit_internal._overload_method # noqa: F811 + # pyrefly: ignore # bad-override def forward( - self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + self, + input: Tensor, + hx: Optional[tuple[Tensor, Tensor]] = None, + # pyrefly: ignore # bad-return ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811 pass @@ -1038,7 +1056,10 @@ class LSTM(RNNBase): @overload @torch._jit_internal._overload_method # noqa: F811 def forward( - self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None + self, + input: PackedSequence, + hx: Optional[tuple[Tensor, Tensor]] = None, + # pyrefly: ignore # bad-return ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811 pass @@ -1152,7 +1173,11 @@ class LSTM(RNNBase): # xxx: isinstance check needs to be in conditional for TorchScript to compile if isinstance(orig_input, PackedSequence): output_packed = PackedSequence( - output, batch_sizes, sorted_indices, unsorted_indices + output, + # pyrefly: ignore # bad-argument-type + batch_sizes, + sorted_indices, + unsorted_indices, ) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: @@ -1318,15 +1343,22 @@ class GRU(RNNBase): @overload # type: ignore[override] @torch._jit_internal._overload_method # noqa: F811 + # pyrefly: ignore # bad-override def forward( - self, input: Tensor, hx: Optional[Tensor] = None + self, + input: Tensor, + hx: Optional[Tensor] = None, + # pyrefly: ignore # bad-return ) -> tuple[Tensor, Tensor]: # noqa: F811 pass @overload @torch._jit_internal._overload_method # noqa: F811 def forward( - self, input: PackedSequence, hx: Optional[Tensor] = None + self, + input: PackedSequence, + hx: Optional[Tensor] = None, + # pyrefly: ignore # bad-return ) -> tuple[PackedSequence, Tensor]: # noqa: F811 pass @@ -1420,7 +1452,11 @@ class GRU(RNNBase): # xxx: isinstance check needs to be in conditional for TorchScript to compile if isinstance(orig_input, PackedSequence): output_packed = PackedSequence( - output, batch_sizes, sorted_indices, unsorted_indices + output, + # pyrefly: ignore # bad-argument-type + batch_sizes, + sorted_indices, + unsorted_indices, ) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index dbee65296660..d5f489c7c56a 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -135,7 +135,11 @@ class Transformer(Module): **factory_kwargs, ) encoder_norm = LayerNorm( - d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs + d_model, + eps=layer_norm_eps, + bias=bias, + # pyrefly: ignore # bad-argument-type + **factory_kwargs, ) self.encoder = TransformerEncoder( encoder_layer, num_encoder_layers, encoder_norm @@ -157,7 +161,11 @@ class Transformer(Module): **factory_kwargs, ) decoder_norm = LayerNorm( - d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs + d_model, + eps=layer_norm_eps, + bias=bias, + 
# pyrefly: ignore # bad-argument-type + **factory_kwargs, ) self.decoder = TransformerDecoder( decoder_layer, num_decoder_layers, decoder_norm @@ -760,7 +768,9 @@ class TransformerEncoderLayer(Module): self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) self.norm_first = norm_first + # pyrefly: ignore # bad-argument-type self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + # pyrefly: ignore # bad-argument-type self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.dropout1 = Dropout(dropout) self.dropout2 = Dropout(dropout) @@ -1052,8 +1062,11 @@ class TransformerDecoderLayer(Module): self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) self.norm_first = norm_first + # pyrefly: ignore # bad-argument-type self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + # pyrefly: ignore # bad-argument-type self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + # pyrefly: ignore # bad-argument-type self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.dropout1 = Dropout(dropout) self.dropout2 = Dropout(dropout) diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 220b8f206b19..d8d8783b06b4 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -36,6 +36,7 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]: import torch if isinstance(out_size, (int, torch.SymInt)): + # pyrefly: ignore # bad-return return out_size if len(defaults) <= len(out_size): raise ValueError(f"Input dimension should be at least {len(out_size) + 1}") diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index 42b3dbd908d6..01ed3030fb84 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -43,6 +43,7 @@ def broadcast(tensor, devices=None, *, out=None): devices = [_get_device_index(d) for d in devices] return torch._C._broadcast(tensor, devices) else: + # pyrefly: ignore # bad-argument-type return torch._C._broadcast_out(tensor, out) @@ -200,6 +201,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out= """ tensor = _handle_complex(tensor) if out is None: + # pyrefly: ignore # not-iterable devices = [_get_device_index(d) for d in devices] return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams)) else: diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index 16bdc204a6bf..22cc3044c221 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -160,6 +160,7 @@ class DataParallel(Module, Generic[T]): self.module = module self.device_ids = [_get_device_index(x, True) for x in device_ids] self.output_device = _get_device_index(output_device, True) + # pyrefly: ignore # read-only self.src_device_obj = torch.device(device_type, self.device_ids[0]) if device_type == "cuda": @@ -173,6 +174,7 @@ class DataParallel(Module, Generic[T]): if not self.device_ids: return self.module(*inputs, **kwargs) + # pyrefly: ignore # bad-argument-type for t in chain(self.module.parameters(), self.module.buffers()): if t.device != self.src_device_obj: raise RuntimeError( @@ -259,8 +261,10 @@ def data_parallel( device_ids = [_get_device_index(x, True) for x in device_ids] output_device = _get_device_index(output_device, True) + # pyrefly: ignore # no-matching-overload src_device_obj = torch.device(device_type, device_ids[0]) + # pyrefly: ignore # 
bad-argument-type for t in chain(module.parameters(), module.buffers()): if t.device != src_device_obj: raise RuntimeError( diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 040d49e17dcc..eeb37b389436 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -241,6 +241,7 @@ class _BufferCommHook: # is completed. class _DDPSink(Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, ddp_weakref, *inputs): # set_materialize_grads(False) will ensure that None gradients stay as # None and are not filled with zeros. @@ -691,6 +692,7 @@ class DistributedDataParallel(Module, Joinable): elif process_group is None and device_mesh is None: self.process_group = _get_default_group() elif device_mesh is None: + # pyrefly: ignore # bad-assignment self.process_group = process_group else: if device_mesh.ndim != 1: @@ -779,11 +781,13 @@ class DistributedDataParallel(Module, Joinable): self.device_ids = None self.output_device = None else: + # pyrefly: ignore # bad-assignment self.device_ids = [_get_device_index(x, True) for x in device_ids] if output_device is None: output_device = device_ids[0] + # pyrefly: ignore # bad-assignment self.output_device = _get_device_index(output_device, True) self.static_graph = False @@ -933,6 +937,7 @@ class DistributedDataParallel(Module, Joinable): # enabled. self._accum_grad_hooks: list[RemovableHandle] = [] if self._use_python_reducer: + # pyrefly: ignore # bad-assignment torch._inductor.config._fuse_ddp_communication = True torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb # Directly adding this to the trace rule will disturb the users diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index 947f56357365..cb167b80b809 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -56,12 +56,16 @@ def scatter(inputs, target_gpus, dim=0): if isinstance(obj, torch.Tensor): return Scatter.apply(target_gpus, None, dim, obj) if _is_namedtuple(obj): + # pyrefly: ignore # no-matching-overload return [type(obj)(*args) for args in zip(*map(scatter_map, obj))] if isinstance(obj, tuple) and len(obj) > 0: + # pyrefly: ignore # no-matching-overload return list(zip(*map(scatter_map, obj))) if isinstance(obj, list) and len(obj) > 0: + # pyrefly: ignore # no-matching-overload return [list(i) for i in zip(*map(scatter_map, obj))] if isinstance(obj, dict) and len(obj) > 0: + # pyrefly: ignore # no-matching-overload return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))] return [obj for _ in target_gpus] @@ -123,9 +127,12 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) if isinstance(out, dict): if not all(len(out) == len(d) for d in outputs): raise ValueError("All dicts must have the same number of keys") + # pyrefly: ignore # not-callable return type(out)((k, gather_map([d[k] for d in outputs])) for k in out) if _is_namedtuple(out): + # pyrefly: ignore # no-matching-overload return type(out)._make(map(gather_map, zip(*outputs))) + # pyrefly: ignore # no-matching-overload return type(out)(map(gather_map, zip(*outputs))) # Recursive function calls like this create reference cycles. 
diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index c41a102fc946..39758f3efd15 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -81,6 +81,7 @@ class Parameter(torch.Tensor, metaclass=_ParameterMeta): memo[id(self)] = result return result + # pyrefly: ignore # bad-override def __repr__(self): return "Parameter containing:\n" + super().__repr__() @@ -143,6 +144,7 @@ class UninitializedTensorMixin: if dtype is None: dtype = self.data.dtype self.data = torch.empty(shape, device=device, dtype=dtype) + # pyrefly: ignore # bad-override, missing-attribute self.__class__ = self.cls_to_become @property @@ -166,6 +168,7 @@ class UninitializedTensorMixin: def __reduce_ex__(self, proto): # See Note [Don't serialize hooks] + # pyrefly: ignore # missing-attribute return (self.__class__, (self.requires_grad,)) @classmethod @@ -175,6 +178,7 @@ class UninitializedTensorMixin: if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper": if kwargs is None: kwargs = {} + # pyrefly: ignore # missing-attribute return super().__torch_function__(func, types, args, kwargs) raise ValueError( f"Attempted to use an uninitialized parameter in {func}. " @@ -216,6 +220,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter): def __new__(cls, requires_grad=True, device=None, dtype=None) -> None: factory_kwargs = {"device": device, "dtype": dtype} data = torch.empty(0, **factory_kwargs) + # pyrefly: ignore # bad-return return torch.Tensor._make_subclass(cls, data, requires_grad) def __deepcopy__(self, memo): @@ -261,7 +266,9 @@ class Buffer(torch.Tensor, metaclass=_BufferMeta): data = torch.empty(0) t = data.detach().requires_grad_(data.requires_grad) + # pyrefly: ignore # missing-attribute t.persistent = persistent + # pyrefly: ignore # missing-attribute t._is_buffer = True return t @@ -292,6 +299,9 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): factory_kwargs = {"device": device, "dtype": dtype} data = torch.empty(0, **factory_kwargs) ret = torch.Tensor._make_subclass(cls, data, requires_grad) + # pyrefly: ignore # missing-attribute ret.persistent = persistent + # pyrefly: ignore # missing-attribute ret._is_buffer = True + # pyrefly: ignore # bad-return return ret diff --git a/torch/nn/utils/__init__.py b/torch/nn/utils/__init__.py index 84145da93f7b..ed9a83b13389 100644 --- a/torch/nn/utils/__init__.py +++ b/torch/nn/utils/__init__.py @@ -1,5 +1,5 @@ from . 
import parametrizations, parametrize, rnn, stateless -from .clip_grad import ( +from .clip_grad import ( # pyrefly: ignore # deprecated _clip_grads_with_norm_ as clip_grads_with_norm_, _get_total_norm as get_total_norm, clip_grad_norm, diff --git a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py index 9ccbc774612e..dba0cd27132d 100644 --- a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -24,6 +24,7 @@ from .expanded_weights_utils import forward_helper @implements_per_sample_grads(F.conv3d) class ConvPerSampleGrad(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward( ctx: Any, kwarg_names: list[str], @@ -56,6 +57,7 @@ class ConvPerSampleGrad(torch.autograd.Function): f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}" ) + # pyrefly: ignore # invalid-type-var ctx.conv_fn = conv_fn ctx.batch_size = orig_input.shape[0] diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index 74418e143860..463d7efb6467 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -237,6 +237,7 @@ def conv_unfold_weight_grad_sample( # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input) # rearrange the above tensor and extract diagonals. + # pyrefly: ignore # no-matching-overload weight_grad_sample = weight_grad_sample.view( n, groups, diff --git a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py index 24acd7549e6d..e1c9dc04d8cf 100644 --- a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py @@ -14,6 +14,7 @@ from .expanded_weights_utils import ( @implements_per_sample_grads(F.embedding) class EmbeddingPerSampleGrad(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward( ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any ) -> torch.Tensor: @@ -34,6 +35,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function): return output @staticmethod + # pyrefly: ignore # bad-override def backward( ctx: Any, grad_output: torch.Tensor ) -> tuple[Optional[torch.Tensor], ...]: diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index 3a5b99c41c65..dd6c6107fe22 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -131,7 +131,9 @@ class ExpandedWeight(torch.Tensor): # in aten, choosing the input or data variants is done by parsing logic. 
This mimics some of that decomp_opts = expanded_weights_rnn_decomps[func] use_input_variant = isinstance( - args[2], list + # pyrefly: ignore # index-error + args[2], + list, ) # data variant uses a list here decomp = decomp_opts[0] if use_input_variant else decomp_opts[1] diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index 1249adfd7594..5f99e468767d 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -8,6 +8,7 @@ from .expanded_weights_impl import ExpandedWeight def is_batch_first(expanded_args_and_kwargs): batch_first = None + # pyrefly: ignore # bad-assignment for arg in expanded_args_and_kwargs: if not isinstance(arg, ExpandedWeight): continue diff --git a/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py index 913bc6cce7b5..1439593408c8 100644 --- a/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py @@ -18,6 +18,7 @@ from .expanded_weights_utils import ( @implements_per_sample_grads(F.group_norm) class GroupNormPerSampleGrad(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): expanded_args, expanded_kwargs = standard_kwargs( kwarg_names, expanded_args_and_kwargs @@ -46,6 +47,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function): return output @staticmethod + # pyrefly: ignore # bad-override def backward(ctx, grad_output): input, num_groups = ctx.input, ctx.num_groups weight, bias, eps = ctx.weight, ctx.bias, ctx.eps @@ -94,7 +96,9 @@ class GroupNormPerSampleGrad(torch.autograd.Function): set_grad_sample_if_exists( weight, lambda _: torch.einsum( - "ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output + "ni...->ni", + # pyrefly: ignore # unsupported-operation + F.group_norm(input, num_groups, eps=eps) * grad_output, ), ) if hasattr(ctx, "bias"): diff --git a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py index 586e29a40f95..7f7fc02dc905 100644 --- a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py @@ -17,6 +17,7 @@ from .expanded_weights_utils import ( @implements_per_sample_grads(F.instance_norm) class InstanceNormPerSampleGrad(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): instance_norm = partial(torch.instance_norm, cudnn_enabled=True) expanded_args, expanded_kwargs = standard_kwargs( @@ -36,6 +37,7 @@ class InstanceNormPerSampleGrad(torch.autograd.Function): return output @staticmethod + # pyrefly: ignore # bad-override def backward(ctx, grad_output): input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var weight, bias, eps = ctx.weight, ctx.bias, ctx.eps diff --git a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py index f223f97460a1..a53ee8a52dab 100644 --- a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py @@ -17,6 +17,7 @@ from .expanded_weights_utils import ( 
@implements_per_sample_grads(F.layer_norm) class LayerNormPerSampleGrad(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): expanded_args, expanded_kwargs = standard_kwargs( kwarg_names, expanded_args_and_kwargs @@ -42,6 +43,7 @@ class LayerNormPerSampleGrad(torch.autograd.Function): return output @staticmethod + # pyrefly: ignore # bad-override def backward(ctx, grad_output): def weight_per_sample_grad(weight): return sum_over_all_but_batch_and_last_n( diff --git a/torch/nn/utils/_expanded_weights/linear_expanded_weights.py b/torch/nn/utils/_expanded_weights/linear_expanded_weights.py index 25b544ed7826..e617c79bb1c4 100644 --- a/torch/nn/utils/_expanded_weights/linear_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/linear_expanded_weights.py @@ -16,6 +16,7 @@ from .expanded_weights_utils import ( @implements_per_sample_grads(F.linear) class LinearPerSampleGrad(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, _, __, *expanded_args_and_kwargs): if len(expanded_args_and_kwargs[0].shape) <= 1: raise RuntimeError( @@ -35,6 +36,7 @@ class LinearPerSampleGrad(torch.autograd.Function): return output @staticmethod + # pyrefly: ignore # bad-override def backward(ctx, grad_output): input, weight = ctx.args bias = ctx.kwargs["bias"] diff --git a/torch/nn/utils/_named_member_accessor.py b/torch/nn/utils/_named_member_accessor.py index 318eb2258ecc..7178b11d00d8 100644 --- a/torch/nn/utils/_named_member_accessor.py +++ b/torch/nn/utils/_named_member_accessor.py @@ -77,6 +77,7 @@ def swap_tensor( setattr(module, name, tensor) elif hasattr(module, name): delattr(module, name) + # pyrefly: ignore # bad-return return orig_tensor diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index 4aefe32d48ed..9d6cc2a2b691 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -41,9 +41,11 @@ def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]: def _no_grad_wrapper(*args, **kwargs): with torch.no_grad(): + # pyrefly: ignore # invalid-param-spec return func(*args, **kwargs) functools.update_wrapper(_no_grad_wrapper, func) + # pyrefly: ignore # bad-return return _no_grad_wrapper diff --git a/torch/nn/utils/memory_format.py b/torch/nn/utils/memory_format.py index 59e54b11e3b9..757b0bb272c8 100644 --- a/torch/nn/utils/memory_format.py +++ b/torch/nn/utils/memory_format.py @@ -84,6 +84,7 @@ def convert_conv2d_weight_memory_format( ) for child in module.children(): convert_conv2d_weight_memory_format(child, memory_format) + # pyrefly: ignore # bad-return return module @@ -163,6 +164,7 @@ def convert_conv3d_weight_memory_format( ) for child in module.children(): convert_conv3d_weight_memory_format(child, memory_format) + # pyrefly: ignore # bad-return return module diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index 5a371af995b6..e93458495617 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -98,6 +98,7 @@ class _Orthogonal(Module): ) # Q is now orthogonal (or unitary) of size (..., n, n) if n != k: + # pyrefly: ignore # unbound-name Q = Q[..., :k] # Q is now the size of the X (albeit perhaps transposed) else: diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index 25de247c6df6..ed298dece3ac 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -179,23 +179,28 @@ class ParametrizationList(ModuleList): # 
Register the tensor(s) if self.is_tensor: + # pyrefly: ignore # missing-attribute if original.dtype != new.dtype: raise ValueError( "When `right_inverse` outputs one tensor, it may not change the dtype.\n" f"original.dtype: {original.dtype}\n" + # pyrefly: ignore # missing-attribute f"right_inverse(original).dtype: {new.dtype}" ) + # pyrefly: ignore # missing-attribute if original.device != new.device: raise ValueError( "When `right_inverse` outputs one tensor, it may not change the device.\n" f"original.device: {original.device}\n" + # pyrefly: ignore # missing-attribute f"right_inverse(original).device: {new.device}" ) # Set the original to original so that the user does not need to re-register the parameter # manually in the optimiser with torch.no_grad(): + # pyrefly: ignore # bad-argument-type _maybe_set(original, new) _register_parameter_or_buffer(self, "original", original) else: @@ -396,6 +401,7 @@ def _inject_property(module: Module, tensor_name: str) -> None: if torch.jit.is_scripting(): raise RuntimeError("Parametrization is not working with scripting.") parametrization = self.parametrizations[tensor_name] + # pyrefly: ignore # redundant-condition if _cache_enabled: if torch.jit.is_scripting(): # Scripting @@ -695,6 +701,7 @@ def remove_parametrizations( # Fetch the original tensor assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy parametrizations = module.parametrizations[tensor_name] + # pyrefly: ignore # invalid-argument if parametrizations.is_tensor: original = parametrizations.original assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor" diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index aee6bdc2ad21..aa0d5c2e7248 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -274,8 +274,11 @@ class PruningContainer(BasePruningMethod): if not isinstance(args, Iterable): # only 1 item self._tensor_name = args._tensor_name self.add_pruning_method(args) + # pyrefly: ignore # bad-argument-type elif len(args) == 1: # only 1 item in a tuple + # pyrefly: ignore # index-error self._tensor_name = args[0]._tensor_name + # pyrefly: ignore # index-error self.add_pruning_method(args[0]) else: # manual construction from list or other iterable (or no args) for method in args: @@ -1097,6 +1100,7 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw # flatten importance scores to consider them all at once in global pruning relevant_importance_scores = torch.nn.utils.parameters_to_vector( + # pyrefly: ignore # bad-argument-type [ importance_scores.get((module, name), getattr(module, name)) for (module, name) in parameters diff --git a/torch/nn/utils/spectral_norm.py b/torch/nn/utils/spectral_norm.py index a1eeb87c24ab..9cf39cc5bda7 100644 --- a/torch/nn/utils/spectral_norm.py +++ b/torch/nn/utils/spectral_norm.py @@ -332,6 +332,7 @@ def spectral_norm( else: dim = 0 SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + # pyrefly: ignore # bad-return return module diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 47f9ca084d79..75559795302a 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -41,10 +41,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable: ) -# pyrefly: ignore # import-error import optree - -# pyrefly: ignore # import-error from optree import PyTreeSpec as TreeSpec # direct import for type annotations diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index 127bed2fc103..c665bb634c5f 100644 --- 
a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -679,7 +679,7 @@ class FlopCounterMode: import tabulate - # pyrefly: ignore # bad-assignment + tabulate.PRESERVE_WHITESPACE = True header = ["Module", "FLOP", "% Total"] values = [] diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index 682848b12d9e..f78d2906779d 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -9,7 +9,7 @@ from typing import Any, Optional import torch import numpy as np -# pyrefly: ignore # import-error + from google.protobuf import struct_pb2 from tensorboard.compat.proto.summary_pb2 import ( diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index 8add89f236b6..6153b0033681 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -956,7 +956,7 @@ class SummaryWriter: ) self._projector_config.embeddings.extend([embedding_info]) - # pyrefly: ignore # import-error + from google.protobuf import text_format config_pbtxt = text_format.MessageToString(self._projector_config)
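
The final hunks go the other way: suppressions that are no longer needed (the tabulate.PRESERVE_WHITESPACE assignment in flop_counter.py, the google.protobuf imports in the tensorboard modules) are removed rather than left behind. Where an ignore does remain, it stays scoped to the single statement the checker flags instead of widening to a file-level exclude. Below is a simplified sketch of that per-line style, using a hypothetical helper that is not part of the patch.

# Illustrative sketch (hypothetical helper, not from this patch): the ignore is
# scoped to the one return that violates the annotation, preserving the loose
# runtime behaviour without excluding the whole file.
from typing import Union


def as_size_list(size: Union[int, list[int]], ndim: int) -> list[int]:
    if isinstance(size, int):
        # Callers historically pass a bare int and expect it back unchanged.
        # pyrefly: ignore # bad-return
        return size
    return list(size)[:ndim]

Keeping the comment adjacent to the offending statement also makes it easy to delete later, which is exactly what the removals above do once an ignore stops being necessary.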