diff --git a/pyrefly.toml b/pyrefly.toml index d4146bf88d4a..73b0e9d28122 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -24,12 +24,12 @@ project-excludes = [ "torch/distributed/**", "torch/nn/**", "torch/_dynamo/**", - "torch/utils/**", # formatting issues "torch/linalg/__init__.py", "torch/package/importer.py", "torch/package/_package_pickler.py", "torch/jit/annotations.py", + "torch/utils/data/datapipes/_typing.py", # ==== "benchmarks/instruction_counts/main.py", "benchmarks/instruction_counts/definitions/setup.py", diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py index 221502ae3190..f1c2db025383 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -59,6 +59,7 @@ class TestBundledInputs(TestCase): # despite having nominally large bundled inputs. augmented_size = model_size(sm) + # pyrefly: ignore # missing-attribute self.assertLess(augmented_size, original_size + (1 << 12)) loaded = save_and_load(sm) @@ -66,12 +67,15 @@ class TestBundledInputs(TestCase): self.assertEqual(loaded.get_num_bundled_inputs(), len(samples)) self.assertEqual(len(inflated), len(samples)) + # pyrefly: ignore # missing-attribute self.assertTrue(loaded(*inflated[0]) is inflated[0][0]) for idx, inp in enumerate(inflated): + # pyrefly: ignore # missing-attribute self.assertIsInstance(inp, tuple) self.assertEqual(len(inp), 1) + # pyrefly: ignore # missing-attribute self.assertIsInstance(inp[0], torch.Tensor) if idx != 5: # Strides might be important for benchmarking. @@ -140,6 +144,7 @@ class TestBundledInputs(TestCase): inflated = loaded.get_all_bundled_inputs() self.assertEqual(inflated, samples) + # pyrefly: ignore # missing-attribute self.assertTrue(loaded(*inflated[0]) == "first 1") def test_multiple_methods_with_inputs(self): @@ -187,6 +192,7 @@ class TestBundledInputs(TestCase): # Check running and size helpers + # pyrefly: ignore # missing-attribute self.assertTrue(loaded(*inflated[0]) is inflated[0][0]) self.assertEqual(loaded.get_num_bundled_inputs(), len(samples)) @@ -420,6 +426,7 @@ class TestBundledInputs(TestCase): augmented_size = model_size(sm) # assert the size has not increased more than 8KB + # pyrefly: ignore # missing-attribute self.assertLess(augmented_size, original_size + (1 << 13)) loaded = save_and_load(sm) diff --git a/test/test_complex.py b/test/test_complex.py index 159f3e18aaee..972fe3f0fd1c 100644 --- a/test/test_complex.py +++ b/test/test_complex.py @@ -48,6 +48,7 @@ class TestComplexTensor(TestCase): def test_all(self, device, dtype): # issue: https://github.com/pytorch/pytorch/issues/120875 x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype) + # pyrefly: ignore # missing-attribute self.assertTrue(torch.all(x)) @dtypes(*complex_types()) @@ -56,6 +57,7 @@ class TestComplexTensor(TestCase): x = torch.tensor( [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype ) + # pyrefly: ignore # missing-attribute self.assertFalse(torch.any(x)) @onlyCPU diff --git a/test/test_type_hints.py b/test/test_type_hints.py index 0aae54be9b63..c982ae19b6df 100644 --- a/test/test_type_hints.py +++ b/test/test_type_hints.py @@ -142,6 +142,7 @@ class TestTypeHints(TestCase): ] ) if result != 0: + # pyrefly: ignore # missing-attribute self.fail(f"mypy failed:\n{stderr}\n{stdout}") diff --git a/test/test_type_info.py b/test/test_type_info.py index 80a21bc5e9dd..48dc083fed1e 100644 --- a/test/test_type_info.py +++ b/test/test_type_info.py @@ -125,6 +125,7 @@ class TestDTypeInfo(TestCase): # Regression test for 
https://github.com/pytorch/pytorch/issues/124868 # If reference count is leaked this would be a set of 10 elements ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)} + # pyrefly: ignore # missing-attribute self.assertLess(len(ref_cnt), 3) self.assertEqual(torch.float64.to_complex(), torch.complex128) @@ -135,6 +136,7 @@ class TestDTypeInfo(TestCase): # Regression test for https://github.com/pytorch/pytorch/issues/124868 # If reference count is leaked this would be a set of 10 elements ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)} + # pyrefly: ignore # missing-attribute self.assertLess(len(ref_cnt), 3) self.assertEqual(torch.complex128.to_real(), torch.double) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 70b3e161cebc..c4396932818d 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -240,7 +240,7 @@ def get_decompositions( registry = global_decomposition_table[type] packets_to_overloads = defaultdict(list) - # pyrefly: ignore # bad-assignment + for opo in registry: if isinstance(opo, (OpOverload, OpOverloadPacket)): packets_to_overloads[opo.overloadpacket].append(opo) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 41f0ff84d1d6..fc2570999686 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4079,6 +4079,7 @@ def _nll_loss_forward( return result, total_weight if weight is not None: + # pyrefly: ignore # unbound-name w = w.expand(self.shape) wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) wsum = torch.where(target != ignore_index, wsum, 0) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 26cf4796fd07..e463023caa77 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -341,8 +341,8 @@ class MtiaInterface(DeviceInterface): synchronize = staticmethod(torch.mtia.synchronize) get_device_properties = staticmethod(torch.mtia.get_device_properties) # type: ignore[assignment] get_raw_stream = staticmethod(get_mtia_stream) # type: ignore[assignment, arg-type] - exchange_device = staticmethod(torch.mtia._exchange_device) # type: ignore[arg-type] - maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device) # type: ignore[arg-type] + exchange_device = staticmethod(torch.mtia._exchange_device) # type: ignore[arg-type, has-type] + maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device) # type: ignore[arg-type, has-type] memory_allocated = staticmethod(torch.mtia.memory_allocated) # type: ignore[assignment] is_bf16_supported = staticmethod(torch.mtia.is_bf16_supported) # type: ignore[arg-type] @@ -414,7 +414,7 @@ class XpuInterface(DeviceInterface): current_device = staticmethod(torch.xpu.current_device) set_device = staticmethod(torch.xpu.set_device) - device_count = staticmethod(torch.xpu.device_count) + device_count = staticmethod(torch.xpu.device_count) # type: ignore[has-type] stream = staticmethod(torch.xpu.stream) # type: ignore[assignment] current_stream = staticmethod(torch.xpu.current_stream) set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment] @@ -422,8 +422,8 @@ class XpuInterface(DeviceInterface): synchronize = staticmethod(torch.xpu.synchronize) get_device_properties = staticmethod(torch.xpu.get_device_properties) # type: ignore[assignment] get_raw_stream = staticmethod(get_xpu_stream) # type: ignore[assignment, arg-type] - exchange_device = 
staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type] - maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type] + exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type, has-type] + maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type, has-type] memory_allocated = staticmethod(torch.xpu.memory_allocated) # Can be mock patched by @patch decorator. diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 8749c62e6695..e2a3be171188 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -1097,7 +1097,6 @@ class TS2FXGraphConverter: # Update the value of loop local variables. if node.outputsSize() >= 1: - # pyrefly: ignore # bad-assignment for i, outp in enumerate(node.outputs()): output_name = outp.debugName() self.name_to_node[output_name] = self.fx_graph.call_function( @@ -1110,7 +1109,7 @@ class TS2FXGraphConverter: fx_block_args[i] = self.name_to_node[output_name] # Update the value of global variables, whose values are modified inplace. - # pyrefly: ignore # bad-assignment + for i, name in enumerate( subgraph_converter.name_update_from_subblock_to_parent ): diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 6e9cca25355a..055a2c7de048 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -140,7 +140,7 @@ def key_path_to_source( source: Source = LocalSource("args") else: source, kp = sourced_prefixes.get(kp) - # pyrefly: ignore # bad-assignment + for k in kp: if isinstance(k, SequenceKey): source = GetItemSource(source, k.idx) diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py index 2fb47c0c7d30..b65df30103eb 100644 --- a/torch/_export/pass_base.py +++ b/torch/_export/pass_base.py @@ -317,6 +317,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): ) res_proxy.node.meta.update(meta.data) if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env): + # pyrefly: ignore # unbound-name if symbol_to_path := compute_unbacked_bindings(shape_env, res_data): res_proxy.node.meta["unbacked_bindings"] = symbol_to_path self.tracer.set_metadata(res_proxy.node, res_data) diff --git a/torch/_export/passes/_node_metadata_hook.py b/torch/_export/passes/_node_metadata_hook.py index 950eccdea9df..d82673e58ec0 100644 --- a/torch/_export/passes/_node_metadata_hook.py +++ b/torch/_export/passes/_node_metadata_hook.py @@ -83,7 +83,6 @@ def _node_metadata_hook( node.meta["torch_fn"] = node.meta.get( "torch_fn", ( - # pyrefly: ignore # missing-attribute f"{node.target.__name__}_0", # pyrefly: ignore # missing-attribute f"{node.target.__class__.__name__}.{node.target.__name__}", diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index 416619cee029..b99d2667c3a9 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -646,6 +646,7 @@ def update_schema(): assert thrift_content[1].startswith("// checksum<<") thrift_checksum_real = _hash_content("\n".join(thrift_content[2:])) + # pyrefly: ignore # import-error from yaml import load, Loader dst = load(content, Loader=Loader) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index b82c2c26e382..5c0e2b25ae77 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -2183,6 +2183,7 @@ class GraphModuleDeserializer(metaclass=Final): simplify=True, ) ): + # pyrefly: ignore # 
unbound-name node.meta["unbacked_bindings"] = unbacked_bindings assert len(self.unbacked_symbols) == 0 diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 939160a48815..a55385425373 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -471,7 +471,6 @@ def _check_input_constraints_for_graph( elif isinstance(node_val, torch.SymInt): _check_symint( node_val, - # pyrefly: ignore # bad-argument-type arg, range_constraints, unification_map, diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 19c55c166151..14409e36dc09 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -204,6 +204,7 @@ def run_functionalized_fw_and_collect_metadata( suppress_pending = contextlib.nullcontext() fake_mode = detect_fake_mode() if fake_mode and (shape_env := fake_mode.shape_env): + # pyrefly: ignore # unbound-name suppress_pending = shape_env.ignore_fresh_unbacked_symbols() with disable_above, mode, suppress_pending: # precondition: The passed in function already handles unflattening inputs + flattening outputs diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 0a3d0e81628a..79cb44a4e348 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -439,7 +439,7 @@ def collect_fw_donated_buffer_idxs( """ storage_refs = set() - # pyrefly: ignore # bad-assignment + for t in itertools.chain(fw_ins, user_fw_outs, bw_outs): # Only access storage if a tensor has storage (not sparse) if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t): diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index a136f51437f0..08b9d869e2ed 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -196,6 +196,7 @@ class MemoryFormatMeta: if use_memory_format: return MemoryFormatMeta( + # pyrefly: ignore # unbound-name memory_format=torch._prims_common.suggest_memory_format(t), ) @@ -892,12 +893,15 @@ class GraphSignature: parameters_to_mutate = {} for output_name, mutation_name in outputs_to_mutations.items(): if mutation_name in user_inputs: + # pyrefly: ignore # unsupported-operation user_inputs_to_mutate[output_name] = mutation_name else: assert mutation_name in buffers or mutation_name in parameters if mutation_name in buffers: + # pyrefly: ignore # unsupported-operation buffers_to_mutate[output_name] = mutation_name else: + # pyrefly: ignore # unsupported-operation parameters_to_mutate[output_name] = mutation_name start, stop = stop, stop + num_user_outputs diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index 21092aebca29..9772ecc6c260 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -232,7 +232,6 @@ def unwrap_tensor_subclasses( attrs, _ = t.__tensor_flatten__() - # pyrefly: ignore # bad-assignment for attr in attrs: inner_tensor = getattr(t, attr) n_desc: Any = ( @@ -314,7 +313,6 @@ def runtime_unwrap_tensor_subclasses( for idx, x in enumerate(wrapped_args): if not is_traceable_wrapper_subclass(x): - # pyrefly: ignore # bad-argument-type xs_inner.append(x) continue diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 
0d5ab74e784c..d674cfc0bf47 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -199,6 +199,7 @@ def _extract_graph_with_inputs_outputs( new_node = new_graph.placeholder(node.name) # Can't use node_copy here as we may be turning previous call_function into placeholders new_node.meta = node.meta + # pyrefly: ignore # unsupported-operation env[node] = new_node for node in joint_graph.nodes: @@ -227,8 +228,10 @@ def _extract_graph_with_inputs_outputs( if any(all_args): env[node] = InvalidNode # type: ignore[assignment] continue + # pyrefly: ignore # unsupported-operation, bad-argument-type env[node] = new_graph.node_copy(node, lambda x: env[x]) elif node.op == "get_attr": + # pyrefly: ignore # unsupported-operation, bad-argument-type env[node] = new_graph.node_copy(node, lambda x: env[x]) elif node.op == "output": pass @@ -1403,12 +1406,14 @@ def functionalize_rng_ops( devices = OrderedSet( get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values() ) + # pyrefly: ignore # unbound-name devices.discard(torch.device("cpu")) # multiple cuda devices won't work with cudagraphs anyway, # fallback to non graphsafe rng checkpointing multi_cuda_devices = len(devices) > 1 # this changes numerics, so if fallback_random is set we will not use it + # pyrefly: ignore # unbound-name ind_config = torch._inductor.config use_rng_graphsafe_rng_functionalization = ( config.graphsafe_rng_functionalization @@ -2840,6 +2845,7 @@ def min_cut_rematerialization_partition( node_info, memory_budget=memory_budget, ) + # pyrefly: ignore # unbound-name if config._sync_decision_cross_ranks: saved_values = _sync_decision_cross_ranks(joint_graph, saved_values) # save_for_backward on tensors and stashes symints in autograd .ctx diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index 574fc8ecab41..3fc9e36ed2c9 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -57,6 +57,7 @@ def _interleave(a, b, dim=0): stacked = torch.stack([a, b], dim=dim + 1) interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1) + # pyrefly: ignore # unbound-name if b_trunc: # TODO: find torch alternative for slice_along dim for torch.jit.script to work interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index 0543ed2b107e..9cf9fbd0f562 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -746,6 +746,7 @@ class WhileLoopAutogradOp(torch.autograd.Function): and (shape_env := loop_count.node.shape_env) and loop_count in shape_env.pending_fresh_unbacked_symbols ): + # pyrefly: ignore # unbound-name shape_env.pending_fresh_unbacked_symbols.remove(loop_count) # Even when body function is not executed, we clone and unsqueeze the input diff --git a/torch/_library/fake_profile.py b/torch/_library/fake_profile.py index c202ab711926..9a835dcd1dba 100644 --- a/torch/_library/fake_profile.py +++ b/torch/_library/fake_profile.py @@ -198,6 +198,7 @@ def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str: to a file. The yaml string can be loaded back into an operator profile structure using `read_profiles_from_yaml`. 
""" + # pyrefly: ignore # import-error import yaml from torch._export.serde.serialize import ( @@ -262,6 +263,7 @@ def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]: """ Reads the yaml saved by `save_op_profiles` and returns the operator profiles. """ + # pyrefly: ignore # import-error import yaml from torch._export.serde.serialize import ( diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 3b268e1731db..26f7c0abd528 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -914,6 +914,7 @@ class TorchLogsFormatter(logging.Formatter): and (trace_id := torch._guards.CompileContext.current_trace_id()) is not None ): + # pyrefly: ignore # unbound-name record.traceid = f" [{trace_id}]" glog_level_to_abbr = { diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 306ee78eecdd..f52c8cbfdfb6 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -113,7 +113,6 @@ def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool: if len(a) != len(b): return False - # pyrefly: ignore # bad-assignment for x, y in zip(a, b): if allow_rhs_unbacked: if isinstance(y, torch.SymInt): diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 0b6869622303..f2f87dea36b9 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -6682,7 +6682,7 @@ def _infer_scalar_type(obj): # double. if length == 0: return torch.get_default_dtype() - # pyrefly: ignore # bad-assignment + for i in range(length): cur_item = obj[i] # TODO: test this diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py index 5d8143fb482f..e4e300bee62a 100644 --- a/torch/_refs/fft.py +++ b/torch/_refs/fft.py @@ -106,13 +106,12 @@ def _resize_fft_input( if x_sizes[dims[i]] < sizes[i]: must_copy = True pad_idx = len(pad_amount) - 2 * dims[i] - 1 - # pyrefly: ignore # unsupported-operation + pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]] if x_sizes[dims[i]] > sizes[i]: x = x.narrow(dims[i], 0, sizes[i]) - # pyrefly: ignore # bad-argument-type return torch.constant_pad_nd(x, pad_amount) if must_copy else x diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 2bd91a432dd1..30220be6be5c 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1374,6 +1374,7 @@ class FakeTensorMode(TorchDispatchMode): return self._stack @count + # pyrefly: ignore # bad-override def __torch_dispatch__( self, func: OpOverload, @@ -2624,6 +2625,7 @@ class FakeTensorMode(TorchDispatchMode): and s.rhs == 1 ): assert self.shape_env is not None + # pyrefly: ignore # unbound-name self.shape_env.set_unbacked_var_to_val(s, int(real_t)) if real_out is not nil: diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index d8ddffbcbb63..1dd0adf42ffd 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1820,6 +1820,7 @@ class MetaConverter(Generic[_TensorT]): # TODO: Use a valid grad-specific symbolic context instead of recycling # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view(). 
+ # pyrefly: ignore # unbound-name r.grad = self.meta_tensor( t.grad, shape_env, @@ -1827,12 +1828,15 @@ class MetaConverter(Generic[_TensorT]): AttrSource(source, "grad"), symbolic_context, ) + # pyrefly: ignore # unbound-name torch._C._set_conj(r, t.is_conj) + # pyrefly: ignore # unbound-name torch._C._set_neg(r, t.is_neg) # This can be skipped if necessary for performance reasons skip_leaf = ( t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE ) + # pyrefly: ignore # unbound-name assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf) # Thanks to storage resizing, it's possible to end up with a tensor # that advertises a real size, but has a storage that actually has zero bytes. @@ -1840,14 +1844,18 @@ class MetaConverter(Generic[_TensorT]): from torch.fx.experimental.symbolic_shapes import guard_or_false if t.storage is not None and guard_or_false(t.storage.size == 0): + # pyrefly: ignore # unbound-name r.untyped_storage().resize_(0) if t.is_parameter: + # pyrefly: ignore # unbound-name r._is_param = True # See Note: [Creating symbolic nested int] if t.nested_int is not None: + # pyrefly: ignore # unbound-name assert _is_fake_tensor(r) + # pyrefly: ignore # unbound-name r.nested_int_memo = r.fake_mode.create_symbolic_nested_int( nt_tensor_id=t.nested_int ) diff --git a/torch/_tensor.py b/torch/_tensor.py index 52e3a2fda8fb..c36ba126d643 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1120,6 +1120,7 @@ class Tensor(torch._C.TensorBase): __rtruediv__ = __rdiv__ __itruediv__ = _C.TensorBase.__idiv__ + # pyrefly: ignore # bad-override __pow__ = cast( Callable[ ["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]], diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index a867ec79093d..af4deb471db2 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -657,8 +657,10 @@ def _str_intern(inp, *, tensor_contents=None): grad_fn_name = "Invalid" if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined] + # pyrefly: ignore # unbound-name grad_fn_name = type(grad_fn).__name__ if grad_fn_name == "CppFunction": + # pyrefly: ignore # unbound-name grad_fn_name = grad_fn.name().rsplit("::", 1)[-1] if grad_fn_name is not None: diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 10c0bf23f85b..37bb1837871a 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -89,6 +89,7 @@ def compile_time_strobelight_meta( skip := kwargs["skip"], int, ): + # pyrefly: ignore # unbound-name kwargs["skip"] = skip + 1 # This is not needed but we have it here to avoid having profile_compile_time diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py index 59e283528d60..6330319be8d8 100644 --- a/torch/ao/ns/fx/graph_passes.py +++ b/torch/ao/ns/fx/graph_passes.py @@ -951,6 +951,7 @@ def create_a_shadows_b( if should_log_inputs: # skip the input logger when inserting a dtype cast if isinstance(prev_node_c, Node): + # pyrefly: ignore # unbound-name prev_node_c = get_normalized_nth_input(node_c, gm_b, 0) elif isinstance(prev_node_c, list): prev_node_c = [ @@ -959,6 +960,7 @@ def create_a_shadows_b( ] dtype_cast_node = _insert_dtype_cast_after_node( subgraph_a.start_node, + # pyrefly: ignore # unbound-name node_c, prev_node_c, gm_a, @@ -1039,7 +1041,10 @@ def create_a_shadows_b( if num_non_param_args_node_a == 2: # node_c_second_non_param_arg = node_c.args[1] node_c_second_non_param_arg = get_normalized_nth_input( - node_c, gm_b, 1 + # pyrefly: ignore # unbound-name + node_c, 
+ gm_b, + 1, ) node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c( dtype_cast_node, @@ -1047,6 +1052,7 @@ def create_a_shadows_b( subgraph_a, gm_a, gm_b, + # pyrefly: ignore # unbound-name node_c.name + "_shadow_copy_", ) env_c[node_a_shadows_c.name] = node_a_shadows_c @@ -1069,11 +1075,15 @@ def create_a_shadows_b( cur_node = node_a_shadows_c while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined] cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment] + # pyrefly: ignore # unbound-name if isinstance(input_logger, Node): + # pyrefly: ignore # unbound-name input_logger_mod = getattr(gm_b, input_logger.name) input_logger_mod.ref_node_name = cur_node.name else: + # pyrefly: ignore # unbound-name assert isinstance(input_logger, list) + # pyrefly: ignore # unbound-name for input_logger_inner in input_logger: input_logger_mod = getattr(gm_b, input_logger_inner.name) input_logger_mod.ref_node_name = cur_node.name diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index 3b2453d8cc28..cef1f5ddaee6 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -93,6 +93,7 @@ class OutputProp: ) if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined] + # pyrefly: ignore # unbound-name node.traced_result = result # pyrefly: ignore # unsupported-operation diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py index d1c5f062a6c1..ec9ca2509b53 100644 --- a/torch/ao/ns/fx/utils.py +++ b/torch/ao/ns/fx/utils.py @@ -404,7 +404,7 @@ def maybe_add_missing_fqns(results: NSResultsType) -> None: for model_name, model_results in model_name_to_results.items(): if model_name == model_name_with_fqns: continue - # pyrefly: ignore # bad-assignment + for i in range(len(model_results)): fqn = ref_model_results[i]["fqn"] model_results[i]["fqn"] = fqn diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py index d3a823543229..f5796ab04718 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_forward_time.py @@ -27,6 +27,7 @@ def run_forward(model, **batch): model(X, lS_o, lS_i) end = time.time() time_taken = end - start + # pyrefly: ignore # bad-argument-type time_list.append(time_taken) avg_time = np.mean(time_list[1:]) return avg_time diff --git a/torch/ao/pruning/_experimental/pruner/prune_functions.py b/torch/ao/pruning/_experimental/pruner/prune_functions.py index 143a1f844ba6..13cf450b6ee4 100644 --- a/torch/ao/pruning/_experimental/pruner/prune_functions.py +++ b/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -127,6 +127,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor: linear.out_features = linear.weight.shape[0] _remove_bias_handles(linear) + # pyrefly: ignore # unbound-name return mask @@ -185,6 +186,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor: conv2d.out_channels = conv2d.weight.shape[0] _remove_bias_handles(conv2d) + # pyrefly: ignore # unbound-name return mask @@ -205,6 +207,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None: new_bias = torch.zeros(conv2d_1.bias.shape) new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined] # adjusted bias that to keep in conv2d_1 + # pyrefly: ignore # unbound-name new_bias[~mask] = cast(Tensor, 
conv2d_1._bias)[~mask] # pruned biases that are kept instead of propagated conv2d_1.bias = nn.Parameter(new_bias) diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 63cd49cb6983..4762494a45c5 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -72,7 +72,6 @@ def _find_q_dq_node_for_user( dq_node = n break if dq_node is None: - # pyrefly: ignore # bad-assignment for n in user.kwargs: if ( isinstance(n, torch.fx.Node) @@ -91,6 +90,7 @@ def _find_q_dq_node_for_user( and arg.op == "call_function" and arg.target in _QUANTIZE_OPS ): + # pyrefly: ignore # unbound-name q_node = arg return (q_node, dq_node) diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index c4148a126d1e..9ea049d7165b 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -414,5 +414,6 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager): def __enter__(self) -> None: pass + # pyrefly: ignore # bad-override def __exit__(self, *args) -> None: torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions) diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 02e06d0b932a..24148eb2bee9 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -48,14 +48,11 @@ class EventList(list): def _remove_dup_nodes(self): while True: to_delete = set() - # pyrefly: ignore # bad-assignment + for idx in range(len(self)): if ( - # pyrefly: ignore # index-error self[idx].cpu_parent is not None - # pyrefly: ignore # index-error and self[idx].cpu_parent.name == self[idx].name - # pyrefly: ignore # index-error and len(self[idx].cpu_parent.cpu_children) == 1 ): self[idx].cpu_parent.cpu_children = self[idx].cpu_children @@ -65,11 +62,11 @@ class EventList(list): to_delete.add(idx) if len(to_delete) == 0: break - # pyrefly: ignore # bad-argument-type + new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete] - # pyrefly: ignore # missing-attribute + self.clear() - # pyrefly: ignore # missing-attribute + self.extend(new_evts) def _populate_cpu_children(self): diff --git a/torch/compiler/_cache.py b/torch/compiler/_cache.py index ecc23d87fb96..8f978dd5690b 100644 --- a/torch/compiler/_cache.py +++ b/torch/compiler/_cache.py @@ -90,6 +90,7 @@ class CacheArtifactFactory: @classmethod def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact: artifact_cls = cls._get_artifact_type(artifact_type_key) + # pyrefly: ignore # bad-instantiation return artifact_cls(key, content) @classmethod @@ -97,6 +98,7 @@ class CacheArtifactFactory: cls, artifact_type_key: str, key: str, content: Any ) -> CacheArtifact: artifact_cls = cls._get_artifact_type(artifact_type_key) + # pyrefly: ignore # bad-instantiation return artifact_cls(key, artifact_cls.encode(content)) diff --git a/torch/export/_trace.py b/torch/export/_trace.py index cdd9e8e53774..8adc890ec1a8 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -377,6 +377,7 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls): nn_module_stack = { root_key: (root, root_cls.__module__ + "." 
+ root_cls.__qualname__), + # pyrefly: ignore # unbound-name **nn_module_stack, } node.meta["nn_module_stack"] = { @@ -525,6 +526,7 @@ def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None: simplify=True, ) ): + # pyrefly: ignore # unbound-name node.meta["unbacked_bindings"] = unbacked_bindings @@ -662,7 +664,6 @@ def _rename_constants_nodes( if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith( const_prefix ): - # pyrefly: ignore # bad-argument-type if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants c_name = rename_constant( const_prefix + spec.arg.name[len(buffer_prefix) :] diff --git a/torch/fx/_graph_pickler.py b/torch/fx/_graph_pickler.py index 62d6fc0a36f8..0d27b3fc390d 100644 --- a/torch/fx/_graph_pickler.py +++ b/torch/fx/_graph_pickler.py @@ -332,7 +332,9 @@ class _TorchNumpyPickleData: if not (name := getattr(np, "__name__", None)): return None + # pyrefly: ignore # unbound-name assert np == getattr(importlib.import_module(mod), name) + # pyrefly: ignore # unbound-name return cls(mod, name) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 5f9e8aec4bff..c9c6412ab4e7 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1793,17 +1793,14 @@ class _ModuleStackTracer(PythonKeyTracer): self.enable_attr_proxy = False self.submodule_paths = {} for name, m in self.scope_root.named_modules(remove_duplicate=False): - # pyrefly: ignore # unsupported-operation if m in self.submodule_paths: log.info( "Shared module found between %s and %s, AttrProxy is enabled.", - # pyrefly: ignore # unsupported-operation self.submodule_paths[m], name, ) self.enable_attr_proxy = True else: - # pyrefly: ignore # unsupported-operation self.submodule_paths[m] = name self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary() @@ -2365,6 +2362,7 @@ class _MakefxTracer: ): from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts + # pyrefly: ignore # unbound-name insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx") t.recompile() # TODO: kind of a bad way to do it, should maybe figure out a better way diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 77c4c482ae91..7e025b6a9a45 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -620,11 +620,13 @@ def rebind_unbacked( ): # This is what the pattern match above is testing repacked = _sympy_cast_symbool_to_symint_guardless( + # pyrefly: ignore # unbound-name sympy.Eq(new_raw_u1, 1) ) assert repacked == raw_u1, f"{repacked} != {raw_u1}" # Cancel the to_int(to_bool(x)). 
This is sound because x in # [0, 1] + # pyrefly: ignore # unbound-name raw_u1 = new_raw_u1 if not isinstance(raw_u1, sympy.Symbol): @@ -1025,6 +1027,7 @@ def find_symbol_binding_fx_nodes( # NB: Prefer first occurrence of symbol for node in graph.nodes: if (s := is_symbol_binding_fx_node(node)) is not None and s not in r: + # pyrefly: ignore # unbound-name r[s] = node return r @@ -1195,10 +1198,13 @@ def _free_unbacked_symbols_with_path( and isinstance(s := expr(a), sympy.Symbol) and s in pending ): + # pyrefly: ignore # unbound-name r[s] = path if shape_env and real is not None: assert isinstance(real, (int, float)) + # pyrefly: ignore # unbound-name shape_env.set_unbacked_var_to_val(s, real) + # pyrefly: ignore # unbound-name pending.remove(s) # When an unbacked SymInt is perfectly divisible by an integer # constant, we replace it with the integer constant to improve @@ -1228,20 +1234,27 @@ def _free_unbacked_symbols_with_path( source=shape_env.var_to_sources.get(s, [None])[0], # type: ignore[union-attr] ) + # pyrefly: ignore # unbound-name unbacked = lhs if lhs in pending else rhs divisor: IntLikeType = ( + # pyrefly: ignore # unbound-name int(coeff) + # pyrefly: ignore # unbound-name if shape_env and isinstance(coeff, sympy.Integer) + # pyrefly: ignore # unbound-name else _symint_wrap(coeff) ) # TODO: DivideByKey needs to test divisibility at runtime! - # pyrefly: ignore # unsupported-operation + r[unbacked] = path + (DivideByKey(divisor),) if real is not None: assert isinstance(real, int) val = ( + # pyrefly: ignore # unbound-name real // int(coeff) + # pyrefly: ignore # unbound-name if isinstance(coeff, sympy.Integer) + # pyrefly: ignore # unbound-name else CleanDiv(real, coeff) ) if shape_env: @@ -1263,7 +1276,9 @@ def _free_unbacked_symbols_with_path( if real is not None: assert type(real) is bool if shape_env: + # pyrefly: ignore # unbound-name shape_env.set_unbacked_var_to_val(s, int(real)) + # pyrefly: ignore # unbound-name pending.remove(s.lhs) return r @@ -1339,6 +1354,7 @@ def compute_unbacked_bindings( ): if ( isinstance(old_sym, SymTypes) + # pyrefly: ignore # unbound-name and (old_s := old_sym.node.expr) != new_s ): # If old_s is not an unbacked_symbol, @@ -1348,11 +1364,15 @@ def compute_unbacked_bindings( # and the original symbol gets replaced by the backed symbol. # When this happens we just replace new_s by the old_s # because we know the value is the same. 
+ # pyrefly: ignore # unbound-name if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s): + # pyrefly: ignore # unbound-name shape_env._rename_unbacked_to(new_s, old_s) else: + # pyrefly: ignore # unbound-name shape_env._eliminate_unbacked(new_s, old_s) elif not isinstance(old_sym, SymTypes): + # pyrefly: ignore # unbound-name shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym)) return symbol_to_path @@ -3317,6 +3337,7 @@ class DimConstraints: and str(symbol := next(iter(c["eq"].free_symbols))) == old_root ): # derived dim with root = old_root new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1 + # pyrefly: ignore # unbound-name new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1 c["eq"] = new_expr @@ -5313,7 +5334,7 @@ class ShapeEnv: ] else: assert len(input_contexts) == len(placeholders) - # pyrefly: ignore # bad-assignment + for i, (t, context) in enumerate(zip(placeholders, input_contexts)): if isinstance(t, Tensorlike): if context is None: @@ -5663,13 +5684,12 @@ class ShapeEnv: ) track_symint(property_source, ss, constraint_size[i]) else: - # pyrefly: ignore # missing-attribute for i, ss in enumerate(curr_t.size()): property_source = TensorPropertySource( src, TensorProperty.SIZE, i ) track_symint(property_source, ss, constraint_size[i]) - # pyrefly: ignore # missing-attribute + for i, ss in enumerate(curr_t.stride()): property_source = TensorPropertySource( src, TensorProperty.STRIDE, i @@ -5677,7 +5697,6 @@ class ShapeEnv: track_symint(property_source, ss, constraint_stride[i]) track_symint( TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), - # pyrefly: ignore # missing-attribute curr_t.storage_offset(), ) @@ -5723,7 +5742,6 @@ class ShapeEnv: continue if is_dim(source): - # pyrefly: ignore # missing-attribute self.dim_constraints.add_equality(source, expr) for exprs, printer, lang in zip(all_exprs, printers, langs): @@ -5877,7 +5895,6 @@ class ShapeEnv: continue expr = self.simplify(ra.expr) - # pyrefly: ignore # missing-attribute self.dim_constraints.add(expr) # 3. 
Every symbol must be within its value range (this handles 0/1 @@ -5894,7 +5911,6 @@ class ShapeEnv: verbose_expr = "" if r.lower not in (-sympy.oo, -int_oo): if any(is_dim(source) for source in sources): - # pyrefly: ignore # missing-attribute self.dim_constraints.add(sympy.Ge(symbol, r.lower)) # Only print lower bound in simplified mode if it is not the # default @@ -5903,7 +5919,6 @@ class ShapeEnv: verbose_expr = f"{r.lower} <= {rf} # {vr_sloc.lower}" if r.upper not in (sympy.oo, int_oo): if any(is_dim(source) for source in sources): - # pyrefly: ignore # missing-attribute self.dim_constraints.add(sympy.Le(symbol, r.upper)) # nontrivial upper bound is always interesting bounds.append(sympy.Le(symbol, r.upper, evaluate=False)) @@ -6152,7 +6167,6 @@ class ShapeEnv: else: bindings[-s] = -arg - # pyrefly: ignore # bad-assignment for t, arg in zip(placeholders, args): if t is None: continue @@ -7588,8 +7602,10 @@ class ShapeEnv: log.info( "oblivious_size %s -> %s (passed counterfactual)", orig_expr, + # pyrefly: ignore # unbound-name correct_hint, ) + # pyrefly: ignore # unbound-name concrete_val = correct_hint # NB: do NOT transmute into runtime assert ok = True @@ -7606,8 +7622,10 @@ class ShapeEnv: ).xreplace(self.var_to_val) ).free_symbols ): + # pyrefly: ignore # unbound-name self._log_real_tensor_propagation(orig_expr, unsound_result) transmute_into_runtime_assert = True + # pyrefly: ignore # unbound-name concrete_val = unsound_result ok = True @@ -8035,7 +8053,6 @@ def _suggest_fixes_for_data_dependent_error_non_strict( if isinstance(leaf, torch.SymInt): src_map[str(leaf.node.expr)].append(name) elif isinstance(leaf, torch.Tensor): - # pyrefly: ignore # bad-assignment for i, dim in enumerate(leaf.shape): if isinstance(dim, torch.SymInt): src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]") diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index f1b229291887..1410bbc5239c 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -407,6 +407,7 @@ class MethodDispatcher(Dispatcher): Dispatcher """ + # pyrefly: ignore # bad-override __slots__ = ("obj", "cls") @classmethod diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 81ca1402a6c7..284078b2371f 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -120,7 +120,7 @@ def _torchscript_schema_to_signature_impl( # which makes it hard to do type annotation kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment] # This renders all previous arguments to positional only - # pyrefly: ignore # bad-assignment + for idx, p in enumerate(parameters): assert p.kind == Parameter.POSITIONAL_OR_KEYWORD parameters[idx] = Parameter( @@ -129,7 +129,7 @@ def _torchscript_schema_to_signature_impl( default=p.default, annotation=p.annotation, ) - # pyrefly: ignore # missing-attribute + parameters.append( Parameter(name=name, kind=kind, default=default, annotation=arg_type) ) @@ -143,7 +143,6 @@ def _torchscript_schema_to_signature_impl( else: return_type = tuple(return_types) - # pyrefly: ignore # bad-argument-type return inspect.Signature(parameters, return_annotation=return_type) diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py index 5d80d47ea2ba..41a9e371344d 100644 --- a/torch/fx/passes/_tensorify_python_scalars.py +++ 
b/torch/fx/passes/_tensorify_python_scalars.py @@ -241,7 +241,6 @@ def tensorify_python_scalars( # pyrefly: ignore # missing-attribute val = node.meta.get("val") if isinstance(val, FakeTensor): - # pyrefly: ignore # bad-assignment for dim in val.shape: if isinstance(dim, torch.SymInt): for s in dim.node.expr.free_symbols: @@ -277,6 +276,7 @@ def tensorify_python_scalars( ): transform = True try: + # pyrefly: ignore # unbound-name proxy = _sympy_interp(zf.node.expr) except NotImplementedError: transform = False @@ -303,6 +303,7 @@ def tensorify_python_scalars( args.append(a) if transform: + # pyrefly: ignore # unbound-name replacement_proxy = replacement_op(*args) # pyrefly: ignore # missing-attribute diff --git a/torch/fx/passes/fake_tensor_prop.py b/torch/fx/passes/fake_tensor_prop.py index 43dbe86c7370..48b35f5183bc 100644 --- a/torch/fx/passes/fake_tensor_prop.py +++ b/torch/fx/passes/fake_tensor_prop.py @@ -93,6 +93,7 @@ class FakeTensorProp(torch.fx.Interpreter): if (shape_env := self._mode.shape_env) and ( symbol_to_path := compute_unbacked_bindings(shape_env, result) ): + # pyrefly: ignore # unbound-name n.meta["unbacked_bindings"] = symbol_to_path return result diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 826e998f5c9c..8fed76cc3893 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -274,7 +274,6 @@ class PassManager: logger.debug("Running pass '%s'", fn_name) try: - # pyrefly: ignore # not-callable res = fn(module) if not isinstance(res, PassResult) and not hasattr( diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 838434e4a0a1..8a147f3e0b00 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -395,21 +395,25 @@ class _MinimizerBase: report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] if self.module_exporter: if isinstance(result_key, tuple): # type: ignore[possibly-undefined] + # pyrefly: ignore # unbound-name result_key = result_key[-1] # If the result is still a tuple (happens in non-sequential mode), # we only use the first element as name. 
if isinstance(result_key, tuple): # type: ignore[possibly-undefined] + # pyrefly: ignore # unbound-name result_key = str(result_key[0]) # pyre-ignore[29]: not a function self.module_exporter( a_input, submodule, + # pyrefly: ignore # unbound-name result_key + "_cpu", ) # pyre-ignore[29]: not a function self.module_exporter( b_input, submodule, + # pyrefly: ignore # unbound-name result_key + "_acc", ) raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index f05982f1adea..f460622db007 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -298,10 +298,14 @@ def insert_deferred_runtime_asserts( and s not in expr_to_proxy ): with _set_node_metadata_hook(gm, _node_metadata_hook): + # pyrefly: ignore # unbound-name expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer) + # pyrefly: ignore # unbound-name log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) + # pyrefly: ignore # unbound-name match_symbol(example_value, lambda: node) + # pyrefly: ignore # unbound-name if isinstance(t := example_value, torch.Tensor): for i, s in enumerate(t.size()): match_symbol( @@ -382,6 +386,7 @@ def insert_deferred_runtime_asserts( # maybe re-reify expression, replace current node if ( + # pyrefly: ignore # unbound-name sym_expr in expr_to_proxy or ( # example value is redundant _is_intermediate_tensor_sym_call(node) @@ -400,20 +405,30 @@ def insert_deferred_runtime_asserts( nn_module_stack=node.meta.get("nn_module_stack"), ), ): + # pyrefly: ignore # unbound-name expr_to_proxy[sym_expr] = _sympy_interp( - expr_to_proxy, sym_expr + expr_to_proxy, + # pyrefly: ignore # unbound-name + sym_expr, ) # type: ignore[arg-type] # won't try DCE-ing tensor compute here hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] node.replace_all_uses_with(hash_node) gm.graph.erase_node(node) log.debug( - "CSE node %s -> %s for expr %s", node, hash_node, sym_expr + "CSE node %s -> %s for expr %s", + node, + hash_node, + # pyrefly: ignore # unbound-name + sym_expr, ) # store node in hash cons, don't delete/replace + # pyrefly: ignore # unbound-name elif sym_expr not in expr_to_proxy and not isinstance( - sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom) + # pyrefly: ignore # unbound-name + sym_expr, + (sympy.Number, sympy.logic.boolalg.BooleanAtom), ): # don't hash cons primitives expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer) # type: ignore[arg-type] diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 52fbbaeaa1be..fb8bcb835ede 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -317,6 +317,7 @@ def split_module( and isinstance(s0 := val.node.expr, sympy.Symbol) and s0 not in symbol_to_node ): + # pyrefly: ignore # unbound-name symbol_to_node[val.node.expr] = node if node.op in ["placeholder", "get_attr", "output"]: diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index 97a60b06694c..0a07da522113 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -84,6 +84,7 @@ def get_source_partitions( if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and ( torch_fn := node.meta.get("torch_fn", None) ) is not None: + # pyrefly: ignore # unbound-name node_fqn, source_fn = torch_fn source_fn_name = source_fn.split(".")[1] if source_fn_name in 
wanted_sources: diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index eebdfad09632..686b33f44085 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -288,7 +288,7 @@ def _replace_pattern( elif isinstance(pattern, Graph): pattern_graph = pattern else: - pattern_graph = symbolic_trace(pattern).graph + pattern_graph = symbolic_trace(pattern).graph # type: ignore[arg-type] matcher = SubgraphMatcher( pattern_graph, @@ -321,7 +321,7 @@ def _replace_pattern( assert replacement_callback is not None, ( "Must provide either a replacement GraphModule or a replacement callback" ) - common_replacement_graph = None + common_replacement_graph = None # type: ignore[assignment] # As we progressively replace nodes, we'll need to keep track of how the match results should change match_changed_node: dict[Node, Node] = {} diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index e5e4ac402be2..f2a6f4a84176 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -561,7 +561,6 @@ def cat(tensors: list[list[int]], dim: int): for i in range(len(tensors)): tensor = tensors[i] if not should_skip(tensor): - # pyrefly: ignore # bad-argument-type check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i) cat_dim_size = cat_dim_size + tensor[dim] diff --git a/torch/onnx/_internal/exporter/_analysis.py b/torch/onnx/_internal/exporter/_analysis.py index 1ff8506283bd..45e87ef2fdae 100644 --- a/torch/onnx/_internal/exporter/_analysis.py +++ b/torch/onnx/_internal/exporter/_analysis.py @@ -128,6 +128,7 @@ def _format_model_info(model_info: ModelInfo) -> str: target_to_messages = {} for node, message in model_info.dispatch_failures: if str(node.target) not in target_to_messages: + # pyrefly: ignore # unsupported-operation target_to_messages[str(node.target)] = message for target, nodes in sorted( diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py index c1b4ce9f4d7f..d388e44fd8c4 100644 --- a/torch/onnx/_internal/fx/passes/type_promotion.py +++ b/torch/onnx/_internal/fx/passes/type_promotion.py @@ -149,6 +149,7 @@ class ElementwiseTypePromotionRule(TypePromotionRule): f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})" ) + # pyrefly: ignore # bad-override def __eq__(self, other: object, /) -> bool: if not isinstance(other, ElementwiseTypePromotionRule): return False @@ -265,6 +266,7 @@ class ReductionTypePromotionRule(TypePromotionRule): def __repr__(self): return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})" + # pyrefly: ignore # bad-override def __eq__(self, other: object, /) -> bool: if not isinstance(other, ElementwiseTypePromotionRule): return False diff --git a/torch/onnx/_internal/torchscript_exporter/jit_utils.py b/torch/onnx/_internal/torchscript_exporter/jit_utils.py index 6c00b6a9c8c4..e0bbe92e0e88 100644 --- a/torch/onnx/_internal/torchscript_exporter/jit_utils.py +++ b/torch/onnx/_internal/torchscript_exporter/jit_utils.py @@ -298,9 +298,12 @@ def _create_node( for key, value in sorted(attributes.items()): if key in _SKIP_NODE_ATTRIBUTES: continue + # pyrefly: ignore # unbound-name _add_attribute(node, key, value, aten=aten) if shape_inference: + # pyrefly: ignore # unbound-name _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version) + # pyrefly: ignore # unbound-name return node diff --git 
a/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py index 858e81766446..cbba5d2e61cb 100644 --- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset11.py @@ -219,7 +219,6 @@ def index_put( if len(indices_list) > 1: for idx_ in range(len(indices_list)): if symbolic_helper._is_bool(indices_list[idx_]): - # pyrefly: ignore # unsupported-operation indices_list[idx_] = g.op("NonZero", indices_list[idx_]) index = indices_list[0] diff --git a/torch/optim/adam.py b/torch/optim/adam.py index fab8e18c4310..f1602a820cec 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -698,7 +698,6 @@ def _multi_tensor_adam( device_exp_avgs, device_grads, cast(float, 1 - device_beta1) ) - # pyrefly: ignore # no-matching-overload torch._foreach_mul_(device_exp_avg_sqs, beta2) # Due to the strictness of the _foreach_addcmul API, we can't have a single diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index 52c317f8bd65..cf2ad5a2f35a 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -115,11 +115,14 @@ def _strong_wolfe( t = _cubic_interpolate( # pyrefly: ignore # index-error bracket[0], + # pyrefly: ignore # unbound-name bracket_f[0], bracket_gtd[0], # type: ignore[possibly-undefined] # pyrefly: ignore # index-error bracket[1], + # pyrefly: ignore # unbound-name bracket_f[1], + # pyrefly: ignore # unbound-name bracket_gtd[1], ) @@ -130,14 +133,20 @@ def _strong_wolfe( # + `t` is at one of the boundary, # we will move `t` to a position which is `0.1 * len(bracket)` # away from the nearest boundary point. + # pyrefly: ignore # unbound-name eps = 0.1 * (max(bracket) - min(bracket)) + # pyrefly: ignore # unbound-name if min(max(bracket) - t, t - min(bracket)) < eps: # interpolation close to boundary + # pyrefly: ignore # unbound-name if insuf_progress or t >= max(bracket) or t <= min(bracket): # evaluate at 0.1 away from boundary + # pyrefly: ignore # unbound-name if abs(t - max(bracket)) < abs(t - min(bracket)): + # pyrefly: ignore # unbound-name t = max(bracket) - eps else: + # pyrefly: ignore # unbound-name t = min(bracket) + eps insuf_progress = False else: @@ -151,13 +160,17 @@ def _strong_wolfe( gtd_new = g_new.dot(d) ls_iter += 1 + # pyrefly: ignore # unbound-name if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: # Armijo condition not satisfied or not lower than lowest point # pyrefly: ignore # unsupported-operation bracket[high_pos] = t + # pyrefly: ignore # unbound-name bracket_f[high_pos] = f_new bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined] + # pyrefly: ignore # unbound-name bracket_gtd[high_pos] = gtd_new + # pyrefly: ignore # unbound-name low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) else: if abs(gtd_new) <= -c2 * gtd: @@ -168,19 +181,24 @@ def _strong_wolfe( # old high becomes new low # pyrefly: ignore # unsupported-operation bracket[high_pos] = bracket[low_pos] + # pyrefly: ignore # unbound-name bracket_f[high_pos] = bracket_f[low_pos] bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined] + # pyrefly: ignore # unbound-name bracket_gtd[high_pos] = bracket_gtd[low_pos] # new point becomes new low # pyrefly: ignore # unsupported-operation bracket[low_pos] = t + # pyrefly: ignore # unbound-name bracket_f[low_pos] = f_new bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: 
ignore[possibly-undefined] + # pyrefly: ignore # unbound-name bracket_gtd[low_pos] = gtd_new # return stuff t = bracket[low_pos] # type: ignore[possibly-undefined] + # pyrefly: ignore # unbound-name f_new = bracket_f[low_pos] g_new = bracket_g[low_pos] # type: ignore[possibly-undefined] return f_new, g_new, t, ls_func_evals diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 1c1ccab7f40a..3a6bc296d70d 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -420,6 +420,7 @@ class LambdaLR(LRScheduler): for idx, fn in enumerate(self.lr_lambdas): if not isinstance(fn, types.FunctionType): + # pyrefly: ignore # unsupported-operation state_dict["lr_lambdas"][idx] = fn.__dict__.copy() return state_dict @@ -539,6 +540,7 @@ class MultiplicativeLR(LRScheduler): for idx, fn in enumerate(self.lr_lambdas): if not isinstance(fn, types.FunctionType): + # pyrefly: ignore # unsupported-operation state_dict["lr_lambdas"][idx] = fn.__dict__.copy() return state_dict @@ -1215,6 +1217,7 @@ class SequentialLR(LRScheduler): state_dict["_schedulers"] = [None] * len(self._schedulers) for idx, s in enumerate(self._schedulers): + # pyrefly: ignore # unsupported-operation state_dict["_schedulers"][idx] = s.state_dict() return state_dict @@ -1557,6 +1560,7 @@ class ChainedScheduler(LRScheduler): state_dict["_schedulers"] = [None] * len(self._schedulers) for idx, s in enumerate(self._schedulers): + # pyrefly: ignore # unsupported-operation state_dict["_schedulers"][idx] = s.state_dict() return state_dict diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 4fb990c9e2d6..d7d4b1f21c53 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -337,7 +337,6 @@ def _single_tensor_sgd( if not torch.jit.is_scripting(): lr = _to_scalar(lr) - # pyrefly: ignore # bad-assignment for i, param in enumerate(params): grad = grads[i] if not maximize else -grads[i] @@ -433,12 +432,10 @@ def _multi_tensor_sgd( all_states_with_momentum_buffer = True for i in range(len(device_momentum_buffer_list)): - # pyrefly: ignore # index-error if device_momentum_buffer_list[i] is None: all_states_with_momentum_buffer = False break else: - # pyrefly: ignore # index-error bufs.append(cast(Tensor, device_momentum_buffer_list[i])) if all_states_with_momentum_buffer: @@ -446,15 +443,13 @@ def _multi_tensor_sgd( torch._foreach_add_(bufs, device_grads, alpha=1 - dampening) else: bufs = [] - # pyrefly: ignore # bad-assignment + for i in range(len(device_momentum_buffer_list)): - # pyrefly: ignore # index-error if device_momentum_buffer_list[i] is None: buf = device_momentum_buffer_list[i] = momentum_buffer_list[ indices[i] ] = device_grads[i].detach().clone() else: - # pyrefly: ignore # index-error buf = cast(Tensor, device_momentum_buffer_list[i]) buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening) diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 84496b087f54..1fa07c90fde1 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -672,7 +672,7 @@ class MemoryProfile: output: list[tuple[int, Action, KeyAndID, int]] = [] allocation_times: dict[tuple[TensorKey, bool], int] = {} live_unknown: dict[tuple[int, torch.device], Literal[True]] = {} - # pyrefly: ignore # bad-assignment + for event in self._op_tree.dfs(): if event.typed[0] == _EventType.Allocation: alloc_fields = event.typed[1] @@ -774,14 +774,12 @@ class MemoryProfile: for key, (_, version) in node.inputs.items() if self._categories.get(key, version) in 
(Category.GRADIENT, Category.PARAMETER) - # pyrefly: ignore # unsupported-operation or key.id in depends_on_gradient ) if ids: - # pyrefly: ignore # missing-attribute depends_on_gradient.update(ids) - # pyrefly: ignore # missing-attribute + depends_on_gradient.update(key.id for key in node.outputs) # We are guaranteed to exit because there is a finite set of @@ -790,7 +788,6 @@ class MemoryProfile: # once to fold the first step into that loop, and a third time # where no new elements are added. if len(depends_on_gradient) == start_size: - # pyrefly: ignore # bad-return return depends_on_gradient def _set_gradients_and_temporaries(self) -> None: diff --git a/torch/sparse/_semi_structured_conversions.py b/torch/sparse/_semi_structured_conversions.py index f9b1b0899f87..c98205f56707 100644 --- a/torch/sparse/_semi_structured_conversions.py +++ b/torch/sparse/_semi_structured_conversions.py @@ -140,6 +140,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): if dense.dtype != torch.float: sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + # pyrefly: ignore # unbound-name sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) else: @@ -172,6 +173,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): meta_offsets = _calculate_meta_reordering_scatter_offsets( m, meta_ncols, meta_dtype, device ) + # pyrefly: ignore # unbound-name meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) return (sparse, meta_reordered.view(m, meta_ncols)) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index a9d4d7d8b616..942e5e8dca3f 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -385,7 +385,7 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): g1 = c_offsets[r + 1] for g in range(g0, g1): p, q = pq[g] - # pyrefly: ignore # unsupported-operation + accumulators[r] += blocks[p] @ others[q] else: _scatter_mm2(blocks, others, c_offsets, pq, accumulators) diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 04ef1aa2e44a..6c4506f1a8a9 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -1219,6 +1219,7 @@ def originate_pairs( else: for pair_type in pair_types: try: + # pyrefly: ignore # bad-instantiation return [pair_type(actual, expected, id=id, **options)] # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the # inputs. Thus, we try the next pair type. 
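For reference, the additions in this patch are predominantly pyrefly's inline suppression comments: a standalone `# pyrefly: ignore # <error-code>` comment placed on the line directly above the statement pyrefly flags, analogous to mypy's trailing `# type: ignore[...]`. A minimal sketch of the pattern, assuming a hypothetical `Config` class that is not part of this change:

    class Config:
        def __init__(self) -> None:
            self.name = "default"

    def describe(cfg: Config) -> str:
        # "version" is attached dynamically below, so a static checker
        # cannot see it; the comment on the line above the access
        # suppresses the flagged diagnostic on the next line.
        # pyrefly: ignore # missing-attribute
        return f"{cfg.name}-{cfg.version}"

    cfg = Config()
    setattr(cfg, "version", 2)  # dynamic attribute, invisible to static analysis
    print(describe(cfg))        # -> "default-2"

When several checks fire on the same line, the codes are comma-separated after the second `#`, e.g. `# pyrefly: ignore # unsupported-operation, bad-return` as in the `OrderedSet.__and__` change below.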
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index af2af999158a..80c81507751b 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -95,7 +95,7 @@ from torch.utils._import_utils import _check_module_exists import torch.utils._pytree as pytree from torch.utils import cpp_extension try: - import pytest + import pytest # type: ignore[import-not-found] has_pytest = True except ImportError: has_pytest = False diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index 10fc16a13c5d..65e0674f3d48 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -117,6 +117,7 @@ def context_decorator(ctx, func): @functools.wraps(func) def decorate_context(*args, **kwargs): + # pyrefly: ignore # bad-context-manager with ctx_factory(): return func(*args, **kwargs) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 708a588fa8d0..47f9ca084d79 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -41,7 +41,10 @@ if not python_pytree._cxx_pytree_dynamo_traceable: ) +# pyrefly: ignore # import-error import optree + +# pyrefly: ignore # import-error from optree import PyTreeSpec as TreeSpec # direct import for type annotations @@ -706,6 +709,7 @@ def tree_map_only( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> PyTree: + # pyrefly: ignore # no-matching-overload return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) @@ -766,6 +770,7 @@ def tree_map_only_( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> PyTree: + # pyrefly: ignore # no-matching-overload return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) @@ -1079,6 +1084,7 @@ def key_get(obj: Any, kp: KeyPath) -> Any: with python_pytree._NODE_REGISTRY_LOCK: + # pyrefly: ignore # bad-assignment python_pytree._cxx_pytree_imported = True args, kwargs = (), {} # type: ignore[var-annotated] for args, kwargs in python_pytree._cxx_pytree_pending_imports: diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index d840139b090c..63c5e4f17e49 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -152,6 +152,7 @@ class DebugMode(TorchDispatchMode): super().__enter__() return self + # pyrefly: ignore # bad-override def __exit__(self, *args): super().__exit__(*args) if self.record_torchfunction: diff --git a/torch/utils/_device.py b/torch/utils/_device.py index de3ee4a9e344..8a2f409c728c 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -60,6 +60,7 @@ def _device_constructors(): # NB: This is directly called from C++ in torch/csrc/Device.cpp class DeviceContext(TorchFunctionMode): def __init__(self, device): + # pyrefly: ignore # read-only self.device = torch.device(device) def __enter__(self): diff --git a/torch/utils/_functools.py b/torch/utils/_functools.py index e862953a908d..37f0a1d17a22 100644 --- a/torch/utils/_functools.py +++ b/torch/utils/_functools.py @@ -35,10 +35,12 @@ def cache_method( if not (cache := getattr(self, cache_name, None)): cache = {} setattr(self, cache_name, cache) + # pyrefly: ignore # unbound-name cached_value = cache.get(args, _cache_sentinel) if cached_value is not _cache_sentinel: return cached_value value = f(self, *args, **kwargs) + # pyrefly: ignore # unbound-name cache[args] = value return value diff --git a/torch/utils/_ordered_set.py b/torch/utils/_ordered_set.py index b2a69fc0ff34..cea8ea684d39 100644 --- 
a/torch/utils/_ordered_set.py +++ b/torch/utils/_ordered_set.py @@ -158,6 +158,7 @@ class OrderedSet(MutableSet[T], Reversible[T]): def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]: # MutableSet impl will iterate over other, iter over smaller of two sets if isinstance(other, OrderedSet) and len(self) < len(other): + # pyrefly: ignore # unsupported-operation, bad-return return other & self return cast(OrderedSet[T], super().__and__(other)) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 07f9b09f9911..53ba046b3ef5 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -708,6 +708,7 @@ class structseq(tuple[_T_co, ...]): def __new__( cls: type[Self], sequence: Iterable[_T_co], + # pyrefly: ignore # bad-function-definition dict: dict[str, Any] = ..., ) -> Self: raise NotImplementedError @@ -754,6 +755,7 @@ def _tuple_flatten_with_keys( d: tuple[T, ...], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _tuple_flatten(d) + # pyrefly: ignore # bad-return return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -767,6 +769,7 @@ def _list_flatten(d: list[T]) -> tuple[list[T], Context]: def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _list_flatten(d) + # pyrefly: ignore # bad-return return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -782,6 +785,7 @@ def _dict_flatten_with_keys( d: dict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _dict_flatten(d) + # pyrefly: ignore # bad-return return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -797,6 +801,7 @@ def _namedtuple_flatten_with_keys( d: NamedTuple, ) -> tuple[list[tuple[KeyEntry, Any]], Context]: values, context = _namedtuple_flatten(d) + # pyrefly: ignore # bad-return return ( [(GetAttrKey(field), v) for field, v in zip(context._fields, values)], context, @@ -846,6 +851,7 @@ def _ordereddict_flatten_with_keys( d: OrderedDict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _ordereddict_flatten(d) + # pyrefly: ignore # bad-return return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -870,6 +876,7 @@ def _defaultdict_flatten_with_keys( ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _defaultdict_flatten(d) _, dict_context = context + # pyrefly: ignore # bad-return return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context @@ -918,6 +925,7 @@ def _deque_flatten_with_keys( d: deque[T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _deque_flatten(d) + # pyrefly: ignore # bad-return return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -1547,6 +1555,7 @@ def tree_map_only( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> PyTree: + # pyrefly: ignore # no-matching-overload return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) @@ -1607,6 +1616,7 @@ def tree_map_only_( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> PyTree: + # pyrefly: ignore # no-matching-overload return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) @@ -1819,6 +1829,7 @@ def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]: for attr in classname.split("."): enum_cls = getattr(enum_cls, attr) enum_cls = cast(type[Enum], enum_cls) + # pyrefly: ignore # unsupported-operation return enum_cls[obj["name"]] return obj diff --git a/torch/utils/_strobelight/cli_function_profiler.py 
b/torch/utils/_strobelight/cli_function_profiler.py index ef1a6edc682e..7825f784e2f3 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -305,6 +305,7 @@ def strobelight( ) -> Callable[_P, Optional[_R]]: @functools.wraps(work_function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: + # pyrefly: ignore # bad-argument-type return profiler.profile(work_function, *args, **kwargs) return wrapper_function diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index b94dd7610689..c2f115ace561 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -105,6 +105,7 @@ def _keep_float( ) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]: @functools.wraps(f) def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: + # pyrefly: ignore # bad-argument-type r: Union[_T, sympy.Float] = f(*args) if any(isinstance(a, sympy.Float) for a in args) and not isinstance( r, sympy.Float @@ -112,6 +113,7 @@ def _keep_float( r = sympy.Float(float(r)) return r + # pyrefly: ignore # bad-return return inner @@ -198,10 +200,12 @@ class FloorDiv(sympy.Function): @property def base(self) -> sympy.Basic: + # pyrefly: ignore # missing-attribute return self.args[0] @property def divisor(self) -> sympy.Basic: + # pyrefly: ignore # missing-attribute return self.args[1] def _sympystr(self, printer: sympy.printing.StrPrinter) -> str: @@ -370,6 +374,7 @@ class ModularIndexing(sympy.Function): return None def _eval_is_nonnegative(self) -> Optional[bool]: + # pyrefly: ignore # missing-attribute p, q = self.args[:2] return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] @@ -450,6 +455,7 @@ class PythonMod(sympy.Function): # - floor(p / q) = 0 # - p % q = p - floor(p / q) * q = p less = p < q + # pyrefly: ignore # missing-attribute if less.is_Boolean and bool(less) and r.is_positive: return p @@ -466,8 +472,11 @@ class PythonMod(sympy.Function): return True if self.args[1].is_negative else None # type: ignore[attr-defined] def _ccode(self, printer): + # pyrefly: ignore # missing-attribute p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) + # pyrefly: ignore # missing-attribute q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5) + # pyrefly: ignore # missing-attribute abs_q = str(q) if self.args[1].is_positive else f"abs({q})" return f"({p} % {q}) < 0 ? 
{p} % {q} + {abs_q} : {p} % {q}" @@ -548,6 +557,7 @@ class CeilToInt(sympy.Function): return sympy.Integer(math.ceil(float(number))) def _ccode(self, printer): + # pyrefly: ignore # missing-attribute number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5) return f"ceil({number})" @@ -818,6 +828,7 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc] if not cond: return ai.func(*[do(i, a) for i in ai.args], evaluate=False) if isinstance(ai, cls): + # pyrefly: ignore # missing-attribute return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False) return a @@ -995,6 +1006,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc] return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] def _eval_is_negative(self): # type:ignore[override] + # pyrefly: ignore # missing-attribute return fuzzy_and(a.is_negative for a in self.args) @@ -1013,6 +1025,7 @@ class Min(MinMaxBase, Application): # type: ignore[misc] return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] def _eval_is_negative(self): # type:ignore[override] + # pyrefly: ignore # missing-attribute return fuzzy_or(a.is_negative for a in self.args) @@ -1150,7 +1163,9 @@ class IntTrueDiv(sympy.Function): return sympy.Float(int(base) / int(divisor)) def _ccode(self, printer): + # pyrefly: ignore # missing-attribute base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) + # pyrefly: ignore # missing-attribute divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5) return f"((int){base}/(int){divisor})" @@ -1310,9 +1325,11 @@ class Identity(sympy.Function): precedence = 10 def __repr__(self): # type: ignore[override] + # pyrefly: ignore # missing-attribute return f"Identity({self.args[0]})" def _eval_is_real(self): + # pyrefly: ignore # missing-attribute return self.args[0].is_real def _eval_is_integer(self): @@ -1320,12 +1337,15 @@ class Identity(sympy.Function): def _eval_expand_identity(self, **hints): # Removes the identity op. + # pyrefly: ignore # missing-attribute return self.args[0] def __int__(self) -> int: + # pyrefly: ignore # missing-attribute return int(self.args[0]) def __float__(self) -> float: + # pyrefly: ignore # missing-attribute return float(self.args[0]) diff --git a/torch/utils/_sympy/numbers.py b/torch/utils/_sympy/numbers.py index d02b9879cad2..01aee8b29f10 100644 --- a/torch/utils/_sympy/numbers.py +++ b/torch/utils/_sympy/numbers.py @@ -9,6 +9,7 @@ from sympy.core.parameters import global_parameters from sympy.core.singleton import S, Singleton +# pyrefly: ignore # invalid-inheritance class IntInfinity(Number, metaclass=Singleton): r"""Positive integer infinite quantity. @@ -203,6 +204,7 @@ class IntInfinity(Number, metaclass=Singleton): int_oo = S.IntInfinity +# pyrefly: ignore # invalid-inheritance class NegativeIntInfinity(Number, metaclass=Singleton): """Negative integer infinite quantity. diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index acfcc596bd49..6f78bc3e12d3 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -66,6 +66,7 @@ class ExprPrinter(StrPrinter): # NB: this pow by natural, you should never have used builtin sympy.pow # for FloatPow, and a symbolic exponent should be PowByNatural. These # means exp is guaranteed to be integer. 
+ # pyrefly: ignore # bad-override def _print_Pow(self, expr: sympy.Expr) -> str: base, exp = expr.args assert exp == int(exp), exp diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 8c960e92f223..05dd8d3eef61 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -175,6 +175,7 @@ class ReferenceAnalysis: @staticmethod def pow(a, b): + # pyrefly: ignore # bad-argument-type return _keep_float(FloatPow)(a, b) @staticmethod diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 2137d21662f3..4ff0e063cc26 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -123,7 +123,9 @@ AllFn2 = Union[ExprFn2, BoolFn2] class ValueRanges(Generic[_T]): if TYPE_CHECKING: # ruff doesn't understand circular references but mypy does + # pyrefly: ignore # unbound-name ExprVR = ValueRanges[sympy.Expr] # noqa: F821 + # pyrefly: ignore # unbound-name BoolVR = ValueRanges[SympyBoolean] # noqa: F821 AllVR = Union[ExprVR, BoolVR] @@ -464,6 +466,7 @@ class SymPyValueRangeAnalysis: @staticmethod def to_dtype(a, dtype, src_dtype=None): if dtype == torch.float64: + # pyrefly: ignore # bad-argument-type return ValueRanges.increasing_map(a, ToFloat) elif dtype == torch.bool: return ValueRanges.unknown_bool() @@ -473,6 +476,7 @@ class SymPyValueRangeAnalysis: @staticmethod def trunc_to_int(a, dtype): + # pyrefly: ignore # bad-argument-type return ValueRanges.increasing_map(a, TruncToInt) @staticmethod @@ -621,7 +625,10 @@ class SymPyValueRangeAnalysis: return ValueRanges.unknown() else: return ValueRanges.coordinatewise_monotone_map( - a, b, _keep_float(IntTrueDiv) + a, + b, + # pyrefly: ignore # bad-argument-type + _keep_float(IntTrueDiv), ) @staticmethod @@ -634,7 +641,10 @@ class SymPyValueRangeAnalysis: return ValueRanges.unknown() else: return ValueRanges.coordinatewise_monotone_map( - a, b, _keep_float(FloatTrueDiv) + a, + b, + # pyrefly: ignore # bad-argument-type + _keep_float(FloatTrueDiv), ) @staticmethod @@ -713,6 +723,7 @@ class SymPyValueRangeAnalysis: # We should know that b >= 0 but we may have forgotten this fact due # to replacements, so don't assert it, but DO clamp it to prevent # degenerate problems + # pyrefly: ignore # no-matching-overload return ValueRanges.coordinatewise_increasing_map( a, b & ValueRanges(0, int_oo), PowByNatural ) @@ -879,6 +890,7 @@ class SymPyValueRangeAnalysis: @classmethod def round_to_int(cls, number, dtype): + # pyrefly: ignore # bad-argument-type return ValueRanges.increasing_map(number, RoundToInt) # It's used in some models on symints @@ -992,6 +1004,7 @@ class SymPyValueRangeAnalysis: @staticmethod def trunc(x): + # pyrefly: ignore # bad-argument-type return ValueRanges.increasing_map(x, TruncToFloat) diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index b54bd25f1016..d034f22b1e69 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -202,6 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) - Args: device (int, optional): if specified, all parameters will be copied to that device """ + # pyrefly: ignore # missing-attribute return self._apply(lambda t: getattr(t, custom_backend_name)(device)) _check_register_once(torch.nn.Module, custom_backend_name) diff --git a/torch/utils/benchmark/examples/sparse/compare.py b/torch/utils/benchmark/examples/sparse/compare.py index 640912e0167e..91e30e68054a 100644 --- 
a/torch/utils/benchmark/examples/sparse/compare.py +++ b/torch/utils/benchmark/examples/sparse/compare.py @@ -63,6 +63,7 @@ def generate_coo_data(size, sparse_dim, nnz, dtype, device): indices = torch.rand(sparse_dim, nnz, device=device) indices.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(indices)) indices = indices.to(torch.long) + # pyrefly: ignore # no-matching-overload values = torch.rand([nnz, ], dtype=dtype, device=device) return indices, values diff --git a/torch/utils/benchmark/utils/compile.py b/torch/utils/benchmark/utils/compile.py index fbca93d221bd..9127b14c99b3 100644 --- a/torch/utils/benchmark/utils/compile.py +++ b/torch/utils/benchmark/utils/compile.py @@ -15,6 +15,7 @@ _warned_tensor_cores = False _default_float_32_precision = torch.get_float32_matmul_precision() try: + from tabulate import tabulate HAS_TABULATE = True @@ -169,6 +170,7 @@ if HAS_TABULATE: _disable_tensor_cores() table.append([ ("Training" if optimizer else "Inference"), + # pyrefly: ignore # redundant-condition backend if backend else "-", mode if mode is not None else "-", f"{compilation_time} ms " if compilation_time else "-", @@ -189,4 +191,5 @@ if HAS_TABULATE: ]) + # pyrefly: ignore # not-callable return tabulate(table, headers=field_names, tablefmt="github") diff --git a/torch/utils/benchmark/utils/cpp_jit.py b/torch/utils/benchmark/utils/cpp_jit.py index b7aec25f6a76..00b4205b8206 100644 --- a/torch/utils/benchmark/utils/cpp_jit.py +++ b/torch/utils/benchmark/utils/cpp_jit.py @@ -35,6 +35,7 @@ def _get_build_root() -> str: global _BUILD_ROOT if _BUILD_ROOT is None: _BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build") + # pyrefly: ignore # missing-argument atexit.register(shutil.rmtree, _BUILD_ROOT) return _BUILD_ROOT diff --git a/torch/utils/benchmark/utils/sparse_fuzzer.py b/torch/utils/benchmark/utils/sparse_fuzzer.py index 498f94ca26f1..42d5dbdbac0d 100644 --- a/torch/utils/benchmark/utils/sparse_fuzzer.py +++ b/torch/utils/benchmark/utils/sparse_fuzzer.py @@ -91,6 +91,7 @@ class FuzzedSparseTensor(FuzzedTensor): return x def _make_tensor(self, params, state): + # pyrefly: ignore # missing-attribute size, _, _ = self._get_size_and_steps(params) density = params['density'] nnz = math.ceil(sum(size) * density) @@ -99,8 +100,10 @@ class FuzzedSparseTensor(FuzzedTensor): is_coalesced = params['coalesced'] sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size) sparse_dim = min(sparse_dim, len(size)) + # pyrefly: ignore # missing-attribute tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced) + # pyrefly: ignore # missing-attribute if self._cuda: tensor = tensor.cuda() sparse_dim = tensor.sparse_dim() @@ -116,6 +119,7 @@ class FuzzedSparseTensor(FuzzedTensor): "sparse_dim": sparse_dim, "dense_dim": dense_dim, "is_hybrid": is_hybrid, + # pyrefly: ignore # missing-attribute "dtype": str(self._dtype), } return tensor, properties diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py index 9590b88ed153..b9c1b65c0599 100644 --- a/torch/utils/benchmark/utils/timer.py +++ b/torch/utils/benchmark/utils/timer.py @@ -233,6 +233,7 @@ class Timer: setup = textwrap.dedent(setup) setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip() + # pyrefly: ignore # bad-instantiation self._timer = self._timer_cls( stmt=stmt, setup=setup, diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py index 
6777d7ba41b4..b821d8bef509 100644 --- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py +++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -448,11 +448,13 @@ class GlobalsBridge: load_lines = [] for name, wrapped_value in self._globals.items(): if wrapped_value.setup is not None: + # pyrefly: ignore # bad-argument-type load_lines.append(textwrap.dedent(wrapped_value.setup)) if wrapped_value.serialization == Serialization.PICKLE: path = os.path.join(self._data_dir, f"{name}.pkl") load_lines.append( + # pyrefly: ignore # bad-argument-type f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)") with open(path, "wb") as f: pickle.dump(wrapped_value.value, f) @@ -462,11 +464,13 @@ class GlobalsBridge: # TODO: Figure out if we can use torch.serialization.add_safe_globals here # Using weights_only=False after the change in # https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573 + # pyrefly: ignore # bad-argument-type load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)") torch.save(wrapped_value.value, path) elif wrapped_value.serialization == Serialization.TORCH_JIT: path = os.path.join(self._data_dir, f"{name}.pt") + # pyrefly: ignore # bad-argument-type load_lines.append(f"{name} = torch.jit.load({repr(path)})") with open(path, "wb") as f: torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call] diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 4358cc4567fd..cee0b82cc793 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -222,6 +222,7 @@ def _get_autocast_kwargs(device_type="cuda"): class CheckpointFunction(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward(ctx, run_function, preserve_rng_state, *args): check_backward_validity(args) ctx.run_function = run_function @@ -784,6 +785,7 @@ class _Holder: class _NoopSaveInputs(torch.autograd.Function): @staticmethod + # pyrefly: ignore # bad-override def forward(*args): return torch.empty((0,)) @@ -1006,6 +1008,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint def logging_mode(): with LoggingTensorMode(), \ capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: + # pyrefly: ignore # bad-assignment self.logs, self.tbs = logs_and_tb yield logs_and_tb return logging_mode() diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index fe3040e2a70e..47eb183f4ee6 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -787,6 +787,7 @@ class BuildExtension(build_ext): # Use absolute path for output_dir so that the object file paths # (`objects`) get generated with absolute paths. 
+ # pyrefly: ignore # no-matching-overload output_dir = os.path.abspath(output_dir) # See Note [Absolute include_dirs] @@ -977,6 +978,7 @@ class BuildExtension(build_ext): is_standalone=False): if not self.compiler.initialized: self.compiler.initialize() + # pyrefly: ignore # no-matching-overload output_dir = os.path.abspath(output_dir) # Note [Absolute include_dirs] @@ -1528,6 +1530,7 @@ def include_paths(device_type: str = "cpu", torch_include_dirs=True) -> list[str # Support CUDA_INC_PATH env variable supported by CMake files if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \ cuda_inc_path != '/usr/include': + # pyrefly: ignore # unbound-name paths.append(cuda_inc_path) if CUDNN_HOME is not None: paths.append(os.path.join(CUDNN_HOME, 'include')) @@ -2569,6 +2572,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]: def _get_vc_env(vc_arch: str) -> dict[str, str]: try: from setuptools import distutils # type: ignore[attr-defined] + # pyrefly: ignore # missing-attribute return distutils._msvccompiler._get_vc_env(vc_arch) except AttributeError: try: diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index e5da3b3ba2b6..b9a04644f331 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -204,6 +204,7 @@ def collate( # check to make sure that the elements in batch have consistent size it = iter(batch) elem_size = len(next(it)) + # pyrefly: ignore # not-iterable if not all(len(elem) == elem_size for elem in it): raise RuntimeError("each element in list of batch should be of equal size") transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index b53c7aef9596..c0a9416c45fe 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -70,6 +70,7 @@ def pin_memory(data, device=None): return clone else: return type(data)( + # pyrefly: ignore # bad-argument-count {k: pin_memory(sample, device) for k, sample in data.items()} ) # type: ignore[call-arg] except TypeError: diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index f48f726e7955..e7466b02c4c6 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -674,6 +674,7 @@ class _BaseDataLoaderIter: # Set pin memory device based on the current accelerator. self._pin_memory_device = ( + # pyrefly: ignore # unbound-name acc.type if self._pin_memory and (acc := torch.accelerator.current_accelerator()) is not None diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index d3ae5b4e18f4..528750157398 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -265,6 +265,7 @@ class _DataPipeType: # Default type for DataPipe without annotation _T_co = TypeVar("_T_co", covariant=True) +# pyrefly: ignore # invalid-annotation _DEFAULT_TYPE = _DataPipeType(Generic[_T_co]) @@ -283,6 +284,7 @@ class _DataPipeMeta(GenericMeta): return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now. 
+ # pyrefly: ignore # no-access cls.__origin__ = None if "type" in namespace: return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index 5884bd15ae08..f5a4ebaf2703 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -80,6 +80,7 @@ class Capture: def _ops_str(self): res = "" + # pyrefly: ignore # not-iterable for op in self.ctx["operations"]: if len(res) > 0: res += "\n" @@ -89,6 +90,7 @@ class Capture: def __getstate__(self): # TODO(VitalyFedyunin): Currently can't pickle (why?) self.ctx["schema_df"] = None + # pyrefly: ignore # not-iterable for var in self.ctx["variables"]: var.calculated_value = None state = {} @@ -112,11 +114,13 @@ class Capture: return CaptureGetItem(self, key, ctx=self.ctx) def __setitem__(self, key, value): + # pyrefly: ignore # missing-attribute self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx)) def __add__(self, add_val): res = CaptureAdd(self, add_val, ctx=self.ctx) var = CaptureVariable(res, ctx=self.ctx) + # pyrefly: ignore # missing-attribute self.ctx["operations"].append( CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) ) @@ -125,6 +129,7 @@ class Capture: def __sub__(self, add_val): res = CaptureSub(self, add_val, ctx=self.ctx) var = CaptureVariable(res, ctx=self.ctx) + # pyrefly: ignore # missing-attribute self.ctx["operations"].append( CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) ) @@ -134,15 +139,19 @@ class Capture: res = CaptureMul(self, add_val, ctx=self.ctx) var = CaptureVariable(res, ctx=self.ctx) t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) + # pyrefly: ignore # missing-attribute self.ctx["operations"].append(t) return var def _is_context_empty(self): + # pyrefly: ignore # bad-argument-type return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0 def apply_ops_2(self, dataframe): # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) + # pyrefly: ignore # unsupported-operation self.ctx["variables"][0].calculated_value = dataframe + # pyrefly: ignore # not-iterable for op in self.ctx["operations"]: op.execute() @@ -175,6 +184,7 @@ class Capture: res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs) var = CaptureVariable(None, ctx=self.ctx) t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res) + # pyrefly: ignore # missing-attribute self.ctx["operations"].append(t) return var @@ -273,7 +283,9 @@ class CaptureVariable(Capture): def apply_ops(self, dataframe): # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) + # pyrefly: ignore # unsupported-operation self.ctx["variables"][0].calculated_value = dataframe + # pyrefly: ignore # not-iterable for op in self.ctx["operations"]: op.execute() return self.calculated_value @@ -373,6 +385,7 @@ def get_val(capture): class CaptureInitial(CaptureVariable): def __init__(self, schema_df=None): + # pyrefly: ignore # bad-assignment new_ctx: dict[str, list[Any]] = { "operations": [], "variables": [], @@ -388,6 +401,7 @@ class CaptureDataFrame(CaptureInitial): class CaptureDataFrameWithDataPipeOps(CaptureDataFrame): def as_datapipe(self): + # pyrefly: ignore # unsupported-operation return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self) def raw_iterator(self): diff --git 
a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py index c9b89d6437aa..2bf0dda77752 100644 --- a/torch/utils/data/datapipes/dataframe/datapipes.py +++ b/torch/utils/data/datapipes/dataframe/datapipes.py @@ -92,6 +92,7 @@ class FilterDataFramesPipe(DFIterDataPipe): size = None all_buffer = [] filter_res = [] + # pyrefly: ignore # bad-assignment for df in self.source_datapipe: if size is None: size = len(df.index) diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index 54b0f7510923..9131b6284374 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -135,6 +135,7 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): _fast_forward_iterator: Optional[Iterator] = None def __iter__(self) -> Iterator[_T_co]: + # pyrefly: ignore # bad-return return self def __getattr__(self, attribute_name): @@ -379,6 +380,7 @@ class _DataPipeSerializationWrapper: value = pickle.dumps(self._datapipe) except Exception: if HAS_DILL: + # pyrefly: ignore # missing-attribute value = dill.dumps(self._datapipe) use_dill = True else: @@ -388,6 +390,7 @@ class _DataPipeSerializationWrapper: def __setstate__(self, state): value, use_dill = state if use_dill: + # pyrefly: ignore # missing-attribute self._datapipe = dill.loads(value) else: self._datapipe = pickle.loads(value) @@ -404,6 +407,7 @@ class _DataPipeSerializationWrapper: class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe): def __init__(self, datapipe: IterDataPipe[_T_co]): super().__init__(datapipe) + # pyrefly: ignore # invalid-type-var self._datapipe_iter: Optional[Iterator[_T_co]] = None def __iter__(self) -> "_IterDataPipeSerializationWrapper": diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 335b8c888668..bfff0d19f4cf 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -118,6 +118,7 @@ class MapperIterDataPipe(IterDataPipe[_T_co]): for idx in sorted(self.input_col[1:], reverse=True): del data[idx] else: + # pyrefly: ignore # unsupported-operation data[self.input_col] = res else: if self.output_col == -1: diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index 063884d7bfa9..e9d19448a85c 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -42,6 +42,7 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]): "Sampler class requires input datapipe implemented `__len__`" ) super().__init__() + # pyrefly: ignore # bad-assignment self.datapipe = datapipe self.sampler_args = () if sampler_args is None else sampler_args self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index d1495dd81863..a62fc2a9cee5 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -59,6 +59,7 @@ class ConcaterIterDataPipe(IterDataPipe): def __len__(self) -> int: if all(isinstance(dp, Sized) for dp in self.datapipes): + # pyrefly: ignore # bad-argument-type return sum(len(dp) for dp in self.datapipes) else: raise TypeError(f"{type(self).__name__} instance doesn't have valid length") @@ -179,6 +180,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate): self._child_stop: 
list[bool] = [True for _ in range(num_instances)] def __len__(self): + # pyrefly: ignore # bad-argument-type return len(self.main_datapipe) def get_next_element_by_instance(self, instance_id: int): @@ -238,6 +240,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate): return self.end_ptr is not None and all(self._child_stop) def get_length_by_instance(self, instance_id: int) -> int: + # pyrefly: ignore # bad-argument-type return len(self.main_datapipe) def reset(self) -> None: @@ -323,6 +326,7 @@ class _ChildDataPipe(IterDataPipe): def __init__(self, main_datapipe: IterDataPipe, instance_id: int): assert isinstance(main_datapipe, _ContainerTemplate) + # pyrefly: ignore # bad-assignment self.main_datapipe: IterDataPipe = main_datapipe self.instance_id = instance_id @@ -449,6 +453,7 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate): drop_none: bool, buffer_size: int, ): + # pyrefly: ignore # invalid-type-var self.main_datapipe = datapipe self._datapipe_iterator: Optional[Iterator[Any]] = None self.num_instances = num_instances @@ -460,7 +465,9 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate): UserWarning, ) self.current_buffer_usage = 0 + # pyrefly: ignore # invalid-type-var self.child_buffers: list[deque[_T_co]] = [deque() for _ in range(num_instances)] + # pyrefly: ignore # invalid-type-var self.classifier_fn = classifier_fn self.drop_none = drop_none self.main_datapipe_exhausted = False @@ -698,6 +705,7 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]): def __len__(self) -> int: if all(isinstance(dp, Sized) for dp in self.datapipes): + # pyrefly: ignore # bad-argument-type return min(len(dp) for dp in self.datapipes) else: raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 72f4cb49b60d..74363a109a06 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -203,7 +203,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]): drop_remaining: bool = False, ): _check_unpickable_fn(group_key_fn) + # pyrefly: ignore # invalid-type-var self.datapipe = datapipe + # pyrefly: ignore # invalid-type-var self.group_key_fn = group_key_fn self.keep_key = keep_key @@ -214,9 +216,11 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]): self.guaranteed_group_size = None if group_size is not None and buffer_size is not None: assert 0 < group_size <= buffer_size + # pyrefly: ignore # bad-assignment self.guaranteed_group_size = group_size if guaranteed_group_size is not None: assert group_size is not None and 0 < guaranteed_group_size <= group_size + # pyrefly: ignore # bad-assignment self.guaranteed_group_size = guaranteed_group_size self.drop_remaining = drop_remaining self.wrapper_class = DataChunk diff --git a/torch/utils/data/datapipes/map/callable.py b/torch/utils/data/datapipes/map/callable.py index be55dc160d73..983ef41748d7 100644 --- a/torch/utils/data/datapipes/map/callable.py +++ b/torch/utils/data/datapipes/map/callable.py @@ -60,6 +60,7 @@ class MapperMapDataPipe(MapDataPipe[_T_co]): self.fn = fn # type: ignore[assignment] def __len__(self) -> int: + # pyrefly: ignore # bad-argument-type return len(self.datapipe) def __getitem__(self, index) -> _T_co: diff --git a/torch/utils/data/datapipes/map/combinatorics.py b/torch/utils/data/datapipes/map/combinatorics.py index 619d0e5c7a0e..b49619c12fd7 100644 --- a/torch/utils/data/datapipes/map/combinatorics.py 
+++ b/torch/utils/data/datapipes/map/combinatorics.py @@ -64,6 +64,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]): ) -> None: super().__init__() self.datapipe = datapipe + # pyrefly: ignore # bad-argument-type self.indices = list(range(len(datapipe))) if indices is None else indices self._enabled = True self._seed = None @@ -95,6 +96,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]): self._shuffled_indices = self._rng.sample(self.indices, len(self.indices)) def __len__(self) -> int: + # pyrefly: ignore # bad-argument-type return len(self.datapipe) def __getstate__(self): diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py index 97f9ef142a7c..b4cb1add714f 100644 --- a/torch/utils/data/datapipes/map/combining.py +++ b/torch/utils/data/datapipes/map/combining.py @@ -49,13 +49,16 @@ class ConcaterMapDataPipe(MapDataPipe): def __getitem__(self, index) -> _T_co: # type: ignore[type-var] offset = 0 for dp in self.datapipes: + # pyrefly: ignore # bad-argument-type if index - offset < len(dp): return dp[index - offset] else: + # pyrefly: ignore # bad-argument-type offset += len(dp) raise IndexError(f"Index {index} is out of range.") def __len__(self) -> int: + # pyrefly: ignore # bad-argument-type return sum(len(dp) for dp in self.datapipes) @@ -102,4 +105,5 @@ class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]): return tuple(res) def __len__(self) -> int: + # pyrefly: ignore # bad-argument-type return min(len(dp) for dp in self.datapipes) diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 55bceec0cb18..6edcee5e35b2 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -196,6 +196,7 @@ def get_file_pathnames_from_root( if match_masks(fname, masks): yield path else: + # pyrefly: ignore # bad-assignment for path, dirs, files in os.walk(root, onerror=onerror): if abspath: path = os.path.abspath(path) diff --git a/torch/utils/data/datapipes/utils/snapshot.py b/torch/utils/data/datapipes/utils/snapshot.py index d120025a934e..5d0f1c0dc84d 100644 --- a/torch/utils/data/datapipes/utils/snapshot.py +++ b/torch/utils/data/datapipes/utils/snapshot.py @@ -43,6 +43,7 @@ def _simple_graph_snapshot_restoration( # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`, # the first reset will not actually reset. datapipe.reset() # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`. 
+ # pyrefly: ignore # bad-argument-type apply_random_seed(datapipe, rng) remainder = n_iterations diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py index 949e3e0c23b4..6f818ff9dfa9 100644 --- a/torch/utils/data/distributed.py +++ b/torch/utils/data/distributed.py @@ -131,6 +131,7 @@ class DistributedSampler(Sampler[_T_co]): indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples + # pyrefly: ignore # bad-return return iter(indices) def __len__(self) -> int: diff --git a/torch/utils/data/graph.py b/torch/utils/data/graph.py index 26a4eae6d18c..63ac99c49268 100644 --- a/torch/utils/data/graph.py +++ b/torch/utils/data/graph.py @@ -72,6 +72,7 @@ def _list_connected_datapipes( p.dump(scan_obj) except (pickle.PickleError, AttributeError, TypeError): if dill_available(): + # pyrefly: ignore # missing-attribute d.dump(scan_obj) else: raise diff --git a/torch/utils/file_baton.py b/torch/utils/file_baton.py index 8437b45d1ffe..b493441db23a 100644 --- a/torch/utils/file_baton.py +++ b/torch/utils/file_baton.py @@ -31,6 +31,7 @@ class FileBaton: True if the file could be created, else False. """ try: + # pyrefly: ignore # bad-assignment self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL) return True except FileExistsError: diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index 45e35130b7b2..127bed2fc103 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -149,6 +149,7 @@ def conv_flop_count( @register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward]) def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int: """Count flops for convolution.""" + # pyrefly: ignore # bad-argument-type return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) @@ -676,7 +677,9 @@ class FlopCounterMode: if depth is None: depth = 999999 + import tabulate + # pyrefly: ignore # bad-assignment tabulate.PRESERVE_WHITESPACE = True header = ["Module", "FLOP", "% Total"] values = [] diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index a2bc9c31ff1b..54442fe403e9 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -48,6 +48,7 @@ MATH_TRANSPILATIONS = collections.OrderedDict( ] ) +# pyrefly: ignore # no-matching-overload CUDA_TYPE_NAME_MAP = collections.OrderedDict( [ ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)), @@ -675,6 +676,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict( ] ) +# pyrefly: ignore # no-matching-overload CUDA_IDENTIFIER_MAP = collections.OrderedDict( [ ("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)), diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 0e816020635b..5b66392403b4 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -663,6 +663,7 @@ def is_caffe2_gpu_file(rel_filepath): return True filename = os.path.basename(rel_filepath) _, ext = os.path.splitext(filename) + # pyrefly: ignore # unsupported-operation return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename) class TrieNode: @@ -1137,6 +1138,7 @@ def hipify( out_of_place_only=out_of_place_only, is_pytorch_extension=is_pytorch_extension)) all_files_set = set(all_files) + # pyrefly: ignore # bad-assignment for f in extra_files: if not os.path.isabs(f): f = 
os.path.join(output_directory, f) diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index e6e93966afdb..a4431c8cc349 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -145,6 +145,7 @@ class BackwardHook: res = out + # pyrefly: ignore # bad-assignment self.grad_outputs = None return self._unpack_none(self.input_tensors_index, res) diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index dd56877c6cb8..ecf0b0fa0c6a 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -208,6 +208,7 @@ def get_model_info( with zipfile.ZipFile(path_or_file) as zf: path_prefix = None zip_files = [] + # pyrefly: ignore # bad-assignment for zi in zf.infolist(): prefix = re.sub("/.*", "", zi.filename) if path_prefix is None: @@ -359,9 +360,12 @@ def get_inline_skeleton(): import importlib.resources + # pyrefly: ignore # bad-argument-type skeleton = importlib.resources.read_text(__package__, "skeleton.html") + # pyrefly: ignore # bad-argument-type js_code = importlib.resources.read_text(__package__, "code.js") for js_module in ["preact", "htm"]: + # pyrefly: ignore # bad-argument-type js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs") js_url = "data:application/javascript," + urllib.parse.quote(js_lib) js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url) diff --git a/torch/utils/tensorboard/_convert_np.py b/torch/utils/tensorboard/_convert_np.py index afa801343334..21290a8b0ced 100644 --- a/torch/utils/tensorboard/_convert_np.py +++ b/torch/utils/tensorboard/_convert_np.py @@ -31,5 +31,7 @@ def make_np(x: torch.Tensor) -> np.ndarray: def _prepare_pytorch(x: torch.Tensor) -> np.ndarray: if x.dtype == torch.bfloat16: x = x.to(torch.float16) + # pyrefly: ignore # bad-assignment x = x.detach().cpu().numpy() + # pyrefly: ignore # bad-return return x diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index 85427162fc77..1577516b3f6d 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -188,6 +188,7 @@ class GraphPy: for key, node in self.nodes_io.items(): if type(node) == NodeBase: + # pyrefly: ignore # unsupported-operation self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName if hasattr(node, "input_or_output"): self.unique_name_to_scoped_name[key] = ( @@ -198,6 +199,7 @@ class GraphPy: self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName if node.scope == "" and self.shallowest_scope_name: self.unique_name_to_scoped_name[node.debugName] = ( + # pyrefly: ignore # unsupported-operation self.shallowest_scope_name + "/" + node.debugName ) diff --git a/torch/utils/tensorboard/_utils.py b/torch/utils/tensorboard/_utils.py index f0ad185d968f..8d9e4a8e09b6 100644 --- a/torch/utils/tensorboard/_utils.py +++ b/torch/utils/tensorboard/_utils.py @@ -57,11 +57,14 @@ def _prepare_video(V): return num != 0 and ((num & (num - 1)) == 0) # pad to nearest power of 2, all at once + # pyrefly: ignore # index-error if not is_power2(V.shape[0]): + # pyrefly: ignore # index-error len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0]) V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0) n_rows = 2 ** ((b.bit_length() - 1) // 2) + # pyrefly: ignore # index-error n_cols = V.shape[0] // n_rows V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w)) diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index 
3fca4d9b7e66..682848b12d9e 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -9,6 +9,7 @@ from typing import Any, Optional import torch import numpy as np +# pyrefly: ignore # import-error from google.protobuf import struct_pb2 from tensorboard.compat.proto.summary_pb2 import ( @@ -497,6 +498,7 @@ def make_histogram(values, bins, max_bins=None): subsampling = num_bins // max_bins subsampling_remainder = num_bins % subsampling if subsampling_remainder != 0: + # pyrefly: ignore # no-matching-overload counts = np.pad( counts, pad_width=[[0, subsampling - subsampling_remainder]], @@ -834,17 +836,21 @@ def compute_curve(labels, predictions, num_thresholds=None, weights=None): weights = 1.0 # Compute bins of true positives and false positives. + # pyrefly: ignore # unsupported-operation bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1))) float_labels = labels.astype(np.float64) + # pyrefly: ignore # unsupported-operation histogram_range = (0, num_thresholds - 1) tp_buckets, _ = np.histogram( bucket_indices, + # pyrefly: ignore # bad-argument-type bins=num_thresholds, range=histogram_range, weights=float_labels * weights, ) fp_buckets, _ = np.histogram( bucket_indices, + # pyrefly: ignore # bad-argument-type bins=num_thresholds, range=histogram_range, weights=(1.0 - float_labels) * weights, diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index 129281cb8ac3..8add89f236b6 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -254,7 +254,9 @@ class SummaryWriter: buckets = [] neg_buckets = [] while v < 1e20: + # pyrefly: ignore # bad-argument-type buckets.append(v) + # pyrefly: ignore # bad-argument-type neg_buckets.append(-v) v *= 1.1 self.default_bins = neg_buckets[::-1] + [0] + buckets @@ -262,15 +264,19 @@ class SummaryWriter: def _get_file_writer(self): """Return the default FileWriter instance. 
Recreates it if closed.""" if self.all_writers is None or self.file_writer is None: + # pyrefly: ignore # bad-assignment self.file_writer = FileWriter( self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix ) + # pyrefly: ignore # bad-assignment, missing-attribute self.all_writers = {self.file_writer.get_logdir(): self.file_writer} if self.purge_step is not None: most_recent_step = self.purge_step + # pyrefly: ignore # missing-attribute self.file_writer.add_event( Event(step=most_recent_step, file_version="brain.Event:2") ) + # pyrefly: ignore # missing-attribute self.file_writer.add_event( Event( step=most_recent_step, @@ -950,6 +956,7 @@ class SummaryWriter: ) self._projector_config.embeddings.extend([embedding_info]) + # pyrefly: ignore # import-error from google.protobuf import text_format config_pbtxt = text_format.MessageToString(self._projector_config) @@ -1199,6 +1206,7 @@ class SummaryWriter: for writer in self.all_writers.values(): writer.flush() writer.close() + # pyrefly: ignore # bad-assignment self.file_writer = self.all_writers = None def __enter__(self): diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 79d8e8b8b171..5ed15c557265 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -461,6 +461,7 @@ def to_html(nodes): if n.context is None: continue s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}')) + # pyrefly: ignore # bad-argument-type listeners.append(s) dot = to_dot(nodes) return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners)) diff --git a/torch/utils/weak.py b/torch/utils/weak.py index 9c7218cb2ad3..cb8862e64531 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -292,6 +292,7 @@ class WeakIdKeyDictionary(MutableMapping): if o is not None: return o, value + # pyrefly: ignore # bad-override def pop(self, key, *args): self._dirty_len = True return self.data.pop(self.ref_type(key), *args) # CHANGED diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 0c7f4cd3ec6b..d1ceb8df2b00 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -240,6 +240,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: # Only keep attributes that are safe for dictionary serialization. serializable_types = (int, float, bool, str, type(None), list, tuple, dict) return { + # pyrefly: ignore # unbound-name key: value for key in dir(props) if not key.startswith("__")