Pyrefly suppressions 6/n (#164877)

Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

INFO 0 errors (5,064 ignored)

Only four directories left to enable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164877
Approved by: https://github.com/oulgen
This commit is contained in:
Maggie Moss
2025-10-08 02:30:53 +00:00
committed by PyTorch MergeBot
parent ad7b2bebc6
commit 086dec3235
123 changed files with 355 additions and 72 deletions

View File

@ -24,12 +24,12 @@ project-excludes = [
"torch/distributed/**",
"torch/nn/**",
"torch/_dynamo/**",
"torch/utils/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",
"torch/jit/annotations.py",
"torch/utils/data/datapipes/_typing.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",

View File

@ -59,6 +59,7 @@ class TestBundledInputs(TestCase):
# despite having nominally large bundled inputs.
augmented_size = model_size(sm)
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 12))
loaded = save_and_load(sm)
@ -66,12 +67,15 @@ class TestBundledInputs(TestCase):
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
self.assertEqual(len(inflated), len(samples))
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
for idx, inp in enumerate(inflated):
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp, tuple)
self.assertEqual(len(inp), 1)
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp[0], torch.Tensor)
if idx != 5:
# Strides might be important for benchmarking.
@ -140,6 +144,7 @@ class TestBundledInputs(TestCase):
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(inflated, samples)
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) == "first 1")
def test_multiple_methods_with_inputs(self):
@ -187,6 +192,7 @@ class TestBundledInputs(TestCase):
# Check running and size helpers
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
@ -420,6 +426,7 @@ class TestBundledInputs(TestCase):
augmented_size = model_size(sm)
# assert the size has not increased more than 8KB
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 13))
loaded = save_and_load(sm)

View File

@ -48,6 +48,7 @@ class TestComplexTensor(TestCase):
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
# pyrefly: ignore # missing-attribute
self.assertTrue(torch.all(x))
@dtypes(*complex_types())
@ -56,6 +57,7 @@ class TestComplexTensor(TestCase):
x = torch.tensor(
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
)
# pyrefly: ignore # missing-attribute
self.assertFalse(torch.any(x))
@onlyCPU

View File

@ -142,6 +142,7 @@ class TestTypeHints(TestCase):
]
)
if result != 0:
# pyrefly: ignore # missing-attribute
self.fail(f"mypy failed:\n{stderr}\n{stdout}")

View File

@ -125,6 +125,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.float64.to_complex(), torch.complex128)
@ -135,6 +136,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.complex128.to_real(), torch.double)

View File

@ -240,7 +240,7 @@ def get_decompositions(
registry = global_decomposition_table[type]
packets_to_overloads = defaultdict(list)
# pyrefly: ignore # bad-assignment
for opo in registry:
if isinstance(opo, (OpOverload, OpOverloadPacket)):
packets_to_overloads[opo.overloadpacket].append(opo)

View File

@ -4079,6 +4079,7 @@ def _nll_loss_forward(
return result, total_weight
if weight is not None:
# pyrefly: ignore # unbound-name
w = w.expand(self.shape)
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
wsum = torch.where(target != ignore_index, wsum, 0)

View File

@ -341,8 +341,8 @@ class MtiaInterface(DeviceInterface):
synchronize = staticmethod(torch.mtia.synchronize)
get_device_properties = staticmethod(torch.mtia.get_device_properties) # type: ignore[assignment]
get_raw_stream = staticmethod(get_mtia_stream) # type: ignore[assignment, arg-type]
exchange_device = staticmethod(torch.mtia._exchange_device) # type: ignore[arg-type]
maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device) # type: ignore[arg-type]
exchange_device = staticmethod(torch.mtia._exchange_device) # type: ignore[arg-type, has-type]
maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device) # type: ignore[arg-type, has-type]
memory_allocated = staticmethod(torch.mtia.memory_allocated) # type: ignore[assignment]
is_bf16_supported = staticmethod(torch.mtia.is_bf16_supported) # type: ignore[arg-type]
@ -414,7 +414,7 @@ class XpuInterface(DeviceInterface):
current_device = staticmethod(torch.xpu.current_device)
set_device = staticmethod(torch.xpu.set_device)
device_count = staticmethod(torch.xpu.device_count)
device_count = staticmethod(torch.xpu.device_count) # type: ignore[has-type]
stream = staticmethod(torch.xpu.stream) # type: ignore[assignment]
current_stream = staticmethod(torch.xpu.current_stream)
set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment]
@ -422,8 +422,8 @@ class XpuInterface(DeviceInterface):
synchronize = staticmethod(torch.xpu.synchronize)
get_device_properties = staticmethod(torch.xpu.get_device_properties) # type: ignore[assignment]
get_raw_stream = staticmethod(get_xpu_stream) # type: ignore[assignment, arg-type]
exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type]
maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type]
exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type, has-type]
maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type, has-type]
memory_allocated = staticmethod(torch.xpu.memory_allocated)
# Can be mock patched by @patch decorator.

View File

@ -1097,7 +1097,6 @@ class TS2FXGraphConverter:
# Update the value of loop local variables.
if node.outputsSize() >= 1:
# pyrefly: ignore # bad-assignment
for i, outp in enumerate(node.outputs()):
output_name = outp.debugName()
self.name_to_node[output_name] = self.fx_graph.call_function(
@ -1110,7 +1109,7 @@ class TS2FXGraphConverter:
fx_block_args[i] = self.name_to_node[output_name]
# Update the value of global variables, whose values are modified inplace.
# pyrefly: ignore # bad-assignment
for i, name in enumerate(
subgraph_converter.name_update_from_subblock_to_parent
):

View File

@ -140,7 +140,7 @@ def key_path_to_source(
source: Source = LocalSource("args")
else:
source, kp = sourced_prefixes.get(kp)
# pyrefly: ignore # bad-assignment
for k in kp:
if isinstance(k, SequenceKey):
source = GetItemSource(source, k.idx)

View File

@ -317,6 +317,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
)
res_proxy.node.meta.update(meta.data)
if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
# pyrefly: ignore # unbound-name
if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
self.tracer.set_metadata(res_proxy.node, res_data)

View File

@ -83,7 +83,6 @@ def _node_metadata_hook(
node.meta["torch_fn"] = node.meta.get(
"torch_fn",
(
# pyrefly: ignore # missing-attribute
f"{node.target.__name__}_0",
# pyrefly: ignore # missing-attribute
f"{node.target.__class__.__name__}.{node.target.__name__}",

View File

@ -646,6 +646,7 @@ def update_schema():
assert thrift_content[1].startswith("// checksum<<")
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))
# pyrefly: ignore # import-error
from yaml import load, Loader
dst = load(content, Loader=Loader)

View File

@ -2183,6 +2183,7 @@ class GraphModuleDeserializer(metaclass=Final):
simplify=True,
)
):
# pyrefly: ignore # unbound-name
node.meta["unbacked_bindings"] = unbacked_bindings
assert len(self.unbacked_symbols) == 0

View File

@ -471,7 +471,6 @@ def _check_input_constraints_for_graph(
elif isinstance(node_val, torch.SymInt):
_check_symint(
node_val,
# pyrefly: ignore # bad-argument-type
arg,
range_constraints,
unification_map,

View File

@ -204,6 +204,7 @@ def run_functionalized_fw_and_collect_metadata(
suppress_pending = contextlib.nullcontext()
fake_mode = detect_fake_mode()
if fake_mode and (shape_env := fake_mode.shape_env):
# pyrefly: ignore # unbound-name
suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
with disable_above, mode, suppress_pending:
# precondition: The passed in function already handles unflattening inputs + flattening outputs

View File

@ -439,7 +439,7 @@ def collect_fw_donated_buffer_idxs(
"""
storage_refs = set()
# pyrefly: ignore # bad-assignment
for t in itertools.chain(fw_ins, user_fw_outs, bw_outs):
# Only access storage if a tensor has storage (not sparse)
if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t):

View File

@ -196,6 +196,7 @@ class MemoryFormatMeta:
if use_memory_format:
return MemoryFormatMeta(
# pyrefly: ignore # unbound-name
memory_format=torch._prims_common.suggest_memory_format(t),
)
@ -892,12 +893,15 @@ class GraphSignature:
parameters_to_mutate = {}
for output_name, mutation_name in outputs_to_mutations.items():
if mutation_name in user_inputs:
# pyrefly: ignore # unsupported-operation
user_inputs_to_mutate[output_name] = mutation_name
else:
assert mutation_name in buffers or mutation_name in parameters
if mutation_name in buffers:
# pyrefly: ignore # unsupported-operation
buffers_to_mutate[output_name] = mutation_name
else:
# pyrefly: ignore # unsupported-operation
parameters_to_mutate[output_name] = mutation_name
start, stop = stop, stop + num_user_outputs

View File

@ -232,7 +232,6 @@ def unwrap_tensor_subclasses(
attrs, _ = t.__tensor_flatten__()
# pyrefly: ignore # bad-assignment
for attr in attrs:
inner_tensor = getattr(t, attr)
n_desc: Any = (
@ -314,7 +313,6 @@ def runtime_unwrap_tensor_subclasses(
for idx, x in enumerate(wrapped_args):
if not is_traceable_wrapper_subclass(x):
# pyrefly: ignore # bad-argument-type
xs_inner.append(x)
continue

View File

@ -199,6 +199,7 @@ def _extract_graph_with_inputs_outputs(
new_node = new_graph.placeholder(node.name)
# Can't use node_copy here as we may be turning previous call_function into placeholders
new_node.meta = node.meta
# pyrefly: ignore # unsupported-operation
env[node] = new_node
for node in joint_graph.nodes:
@ -227,8 +228,10 @@ def _extract_graph_with_inputs_outputs(
if any(all_args):
env[node] = InvalidNode # type: ignore[assignment]
continue
# pyrefly: ignore # unsupported-operation, bad-argument-type
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "get_attr":
# pyrefly: ignore # unsupported-operation, bad-argument-type
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "output":
pass
@ -1403,12 +1406,14 @@ def functionalize_rng_ops(
devices = OrderedSet(
get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values()
)
# pyrefly: ignore # unbound-name
devices.discard(torch.device("cpu"))
# multiple cuda devices won't work with cudagraphs anyway,
# fallback to non graphsafe rng checkpointing
multi_cuda_devices = len(devices) > 1
# this changes numerics, so if fallback_random is set we will not use it
# pyrefly: ignore # unbound-name
ind_config = torch._inductor.config
use_rng_graphsafe_rng_functionalization = (
config.graphsafe_rng_functionalization
@ -2840,6 +2845,7 @@ def min_cut_rematerialization_partition(
node_info,
memory_budget=memory_budget,
)
# pyrefly: ignore # unbound-name
if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_graph, saved_values)
# save_for_backward on tensors and stashes symints in autograd .ctx

View File

@ -57,6 +57,7 @@ def _interleave(a, b, dim=0):
stacked = torch.stack([a, b], dim=dim + 1)
interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1)
# pyrefly: ignore # unbound-name
if b_trunc:
# TODO: find torch alternative for slice_along dim for torch.jit.script to work
interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1)

View File

@ -746,6 +746,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
and (shape_env := loop_count.node.shape_env)
and loop_count in shape_env.pending_fresh_unbacked_symbols
):
# pyrefly: ignore # unbound-name
shape_env.pending_fresh_unbacked_symbols.remove(loop_count)
# Even when body function is not executed, we clone and unsqueeze the input

View File

@ -198,6 +198,7 @@ def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str:
to a file. The yaml string can be loaded back into an operator profile
structure using `read_profiles_from_yaml`.
"""
# pyrefly: ignore # import-error
import yaml
from torch._export.serde.serialize import (
@ -262,6 +263,7 @@ def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]:
"""
Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
"""
# pyrefly: ignore # import-error
import yaml
from torch._export.serde.serialize import (

View File

@ -914,6 +914,7 @@ class TorchLogsFormatter(logging.Formatter):
and (trace_id := torch._guards.CompileContext.current_trace_id())
is not None
):
# pyrefly: ignore # unbound-name
record.traceid = f" [{trace_id}]"
glog_level_to_abbr = {

View File

@ -113,7 +113,6 @@ def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
if len(a) != len(b):
return False
# pyrefly: ignore # bad-assignment
for x, y in zip(a, b):
if allow_rhs_unbacked:
if isinstance(y, torch.SymInt):

View File

@ -6682,7 +6682,7 @@ def _infer_scalar_type(obj):
# double.
if length == 0:
return torch.get_default_dtype()
# pyrefly: ignore # bad-assignment
for i in range(length):
cur_item = obj[i]
# TODO: test this

View File

@ -106,13 +106,12 @@ def _resize_fft_input(
if x_sizes[dims[i]] < sizes[i]:
must_copy = True
pad_idx = len(pad_amount) - 2 * dims[i] - 1
# pyrefly: ignore # unsupported-operation
pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]
if x_sizes[dims[i]] > sizes[i]:
x = x.narrow(dims[i], 0, sizes[i])
# pyrefly: ignore # bad-argument-type
return torch.constant_pad_nd(x, pad_amount) if must_copy else x

View File

@ -1374,6 +1374,7 @@ class FakeTensorMode(TorchDispatchMode):
return self._stack
@count
# pyrefly: ignore # bad-override
def __torch_dispatch__(
self,
func: OpOverload,
@ -2624,6 +2625,7 @@ class FakeTensorMode(TorchDispatchMode):
and s.rhs == 1
):
assert self.shape_env is not None
# pyrefly: ignore # unbound-name
self.shape_env.set_unbacked_var_to_val(s, int(real_t))
if real_out is not nil:

View File

@ -1820,6 +1820,7 @@ class MetaConverter(Generic[_TensorT]):
# TODO: Use a valid grad-specific symbolic context instead of recycling
# the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
# pyrefly: ignore # unbound-name
r.grad = self.meta_tensor(
t.grad,
shape_env,
@ -1827,12 +1828,15 @@ class MetaConverter(Generic[_TensorT]):
AttrSource(source, "grad"),
symbolic_context,
)
# pyrefly: ignore # unbound-name
torch._C._set_conj(r, t.is_conj)
# pyrefly: ignore # unbound-name
torch._C._set_neg(r, t.is_neg)
# This can be skipped if necessary for performance reasons
skip_leaf = (
t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
)
# pyrefly: ignore # unbound-name
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
# Thanks to storage resizing, it's possible to end up with a tensor
# that advertises a real size, but has a storage that actually has zero bytes.
@ -1840,14 +1844,18 @@ class MetaConverter(Generic[_TensorT]):
from torch.fx.experimental.symbolic_shapes import guard_or_false
if t.storage is not None and guard_or_false(t.storage.size == 0):
# pyrefly: ignore # unbound-name
r.untyped_storage().resize_(0)
if t.is_parameter:
# pyrefly: ignore # unbound-name
r._is_param = True
# See Note: [Creating symbolic nested int]
if t.nested_int is not None:
# pyrefly: ignore # unbound-name
assert _is_fake_tensor(r)
# pyrefly: ignore # unbound-name
r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
nt_tensor_id=t.nested_int
)

View File

@ -1120,6 +1120,7 @@ class Tensor(torch._C.TensorBase):
__rtruediv__ = __rdiv__
__itruediv__ = _C.TensorBase.__idiv__
# pyrefly: ignore # bad-override
__pow__ = cast(
Callable[
["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]],

View File

@ -657,8 +657,10 @@ def _str_intern(inp, *, tensor_contents=None):
grad_fn_name = "Invalid"
if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
grad_fn_name = type(grad_fn).__name__
if grad_fn_name == "CppFunction":
# pyrefly: ignore # unbound-name
grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]
if grad_fn_name is not None:

View File

@ -89,6 +89,7 @@ def compile_time_strobelight_meta(
skip := kwargs["skip"],
int,
):
# pyrefly: ignore # unbound-name
kwargs["skip"] = skip + 1
# This is not needed but we have it here to avoid having profile_compile_time

View File

@ -951,6 +951,7 @@ def create_a_shadows_b(
if should_log_inputs:
# skip the input logger when inserting a dtype cast
if isinstance(prev_node_c, Node):
# pyrefly: ignore # unbound-name
prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
elif isinstance(prev_node_c, list):
prev_node_c = [
@ -959,6 +960,7 @@ def create_a_shadows_b(
]
dtype_cast_node = _insert_dtype_cast_after_node(
subgraph_a.start_node,
# pyrefly: ignore # unbound-name
node_c,
prev_node_c,
gm_a,
@ -1039,7 +1041,10 @@ def create_a_shadows_b(
if num_non_param_args_node_a == 2:
# node_c_second_non_param_arg = node_c.args[1]
node_c_second_non_param_arg = get_normalized_nth_input(
node_c, gm_b, 1
# pyrefly: ignore # unbound-name
node_c,
gm_b,
1,
)
node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
dtype_cast_node,
@ -1047,6 +1052,7 @@ def create_a_shadows_b(
subgraph_a,
gm_a,
gm_b,
# pyrefly: ignore # unbound-name
node_c.name + "_shadow_copy_",
)
env_c[node_a_shadows_c.name] = node_a_shadows_c
@ -1069,11 +1075,15 @@ def create_a_shadows_b(
cur_node = node_a_shadows_c
while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined]
cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
# pyrefly: ignore # unbound-name
if isinstance(input_logger, Node):
# pyrefly: ignore # unbound-name
input_logger_mod = getattr(gm_b, input_logger.name)
input_logger_mod.ref_node_name = cur_node.name
else:
# pyrefly: ignore # unbound-name
assert isinstance(input_logger, list)
# pyrefly: ignore # unbound-name
for input_logger_inner in input_logger:
input_logger_mod = getattr(gm_b, input_logger_inner.name)
input_logger_mod.ref_node_name = cur_node.name

View File

@ -93,6 +93,7 @@ class OutputProp:
)
if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
node.traced_result = result
# pyrefly: ignore # unsupported-operation

View File

@ -404,7 +404,7 @@ def maybe_add_missing_fqns(results: NSResultsType) -> None:
for model_name, model_results in model_name_to_results.items():
if model_name == model_name_with_fqns:
continue
# pyrefly: ignore # bad-assignment
for i in range(len(model_results)):
fqn = ref_model_results[i]["fqn"]
model_results[i]["fqn"] = fqn

View File

@ -27,6 +27,7 @@ def run_forward(model, **batch):
model(X, lS_o, lS_i)
end = time.time()
time_taken = end - start
# pyrefly: ignore # bad-argument-type
time_list.append(time_taken)
avg_time = np.mean(time_list[1:])
return avg_time

View File

@ -127,6 +127,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor:
linear.out_features = linear.weight.shape[0]
_remove_bias_handles(linear)
# pyrefly: ignore # unbound-name
return mask
@ -185,6 +186,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
conv2d.out_channels = conv2d.weight.shape[0]
_remove_bias_handles(conv2d)
# pyrefly: ignore # unbound-name
return mask
@ -205,6 +207,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
new_bias = torch.zeros(conv2d_1.bias.shape)
new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined]
# adjusted bias that to keep in conv2d_1
# pyrefly: ignore # unbound-name
new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
# pruned biases that are kept instead of propagated
conv2d_1.bias = nn.Parameter(new_bias)

View File

@ -72,7 +72,6 @@ def _find_q_dq_node_for_user(
dq_node = n
break
if dq_node is None:
# pyrefly: ignore # bad-assignment
for n in user.kwargs:
if (
isinstance(n, torch.fx.Node)
@ -91,6 +90,7 @@ def _find_q_dq_node_for_user(
and arg.op == "call_function"
and arg.target in _QUANTIZE_OPS
):
# pyrefly: ignore # unbound-name
q_node = arg
return (q_node, dq_node)

View File

@ -414,5 +414,6 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):
def __enter__(self) -> None:
pass
# pyrefly: ignore # bad-override
def __exit__(self, *args) -> None:
torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)

View File

@ -48,14 +48,11 @@ class EventList(list):
def _remove_dup_nodes(self):
while True:
to_delete = set()
# pyrefly: ignore # bad-assignment
for idx in range(len(self)):
if (
# pyrefly: ignore # index-error
self[idx].cpu_parent is not None
# pyrefly: ignore # index-error
and self[idx].cpu_parent.name == self[idx].name
# pyrefly: ignore # index-error
and len(self[idx].cpu_parent.cpu_children) == 1
):
self[idx].cpu_parent.cpu_children = self[idx].cpu_children
@ -65,11 +62,11 @@ class EventList(list):
to_delete.add(idx)
if len(to_delete) == 0:
break
# pyrefly: ignore # bad-argument-type
new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
# pyrefly: ignore # missing-attribute
self.clear()
# pyrefly: ignore # missing-attribute
self.extend(new_evts)
def _populate_cpu_children(self):

View File

@ -90,6 +90,7 @@ class CacheArtifactFactory:
@classmethod
def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact:
artifact_cls = cls._get_artifact_type(artifact_type_key)
# pyrefly: ignore # bad-instantiation
return artifact_cls(key, content)
@classmethod
@ -97,6 +98,7 @@ class CacheArtifactFactory:
cls, artifact_type_key: str, key: str, content: Any
) -> CacheArtifact:
artifact_cls = cls._get_artifact_type(artifact_type_key)
# pyrefly: ignore # bad-instantiation
return artifact_cls(key, artifact_cls.encode(content))

View File

@ -377,6 +377,7 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
nn_module_stack = {
root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
# pyrefly: ignore # unbound-name
**nn_module_stack,
}
node.meta["nn_module_stack"] = {
@ -525,6 +526,7 @@ def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None:
simplify=True,
)
):
# pyrefly: ignore # unbound-name
node.meta["unbacked_bindings"] = unbacked_bindings
@ -662,7 +664,6 @@ def _rename_constants_nodes(
if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
const_prefix
):
# pyrefly: ignore # bad-argument-type
if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants
c_name = rename_constant(
const_prefix + spec.arg.name[len(buffer_prefix) :]

View File

@ -332,7 +332,9 @@ class _TorchNumpyPickleData:
if not (name := getattr(np, "__name__", None)):
return None
# pyrefly: ignore # unbound-name
assert np == getattr(importlib.import_module(mod), name)
# pyrefly: ignore # unbound-name
return cls(mod, name)

View File

@ -1793,17 +1793,14 @@ class _ModuleStackTracer(PythonKeyTracer):
self.enable_attr_proxy = False
self.submodule_paths = {}
for name, m in self.scope_root.named_modules(remove_duplicate=False):
# pyrefly: ignore # unsupported-operation
if m in self.submodule_paths:
log.info(
"Shared module found between %s and %s, AttrProxy is enabled.",
# pyrefly: ignore # unsupported-operation
self.submodule_paths[m],
name,
)
self.enable_attr_proxy = True
else:
# pyrefly: ignore # unsupported-operation
self.submodule_paths[m] = name
self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary()
@ -2365,6 +2362,7 @@ class _MakefxTracer:
):
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
# pyrefly: ignore # unbound-name
insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx")
t.recompile()
# TODO: kind of a bad way to do it, should maybe figure out a better way

View File

@ -620,11 +620,13 @@ def rebind_unbacked(
):
# This is what the pattern match above is testing
repacked = _sympy_cast_symbool_to_symint_guardless(
# pyrefly: ignore # unbound-name
sympy.Eq(new_raw_u1, 1)
)
assert repacked == raw_u1, f"{repacked} != {raw_u1}"
# Cancel the to_int(to_bool(x)). This is sound because x in
# [0, 1]
# pyrefly: ignore # unbound-name
raw_u1 = new_raw_u1
if not isinstance(raw_u1, sympy.Symbol):
@ -1025,6 +1027,7 @@ def find_symbol_binding_fx_nodes(
# NB: Prefer first occurrence of symbol
for node in graph.nodes:
if (s := is_symbol_binding_fx_node(node)) is not None and s not in r:
# pyrefly: ignore # unbound-name
r[s] = node
return r
@ -1195,10 +1198,13 @@ def _free_unbacked_symbols_with_path(
and isinstance(s := expr(a), sympy.Symbol)
and s in pending
):
# pyrefly: ignore # unbound-name
r[s] = path
if shape_env and real is not None:
assert isinstance(real, (int, float))
# pyrefly: ignore # unbound-name
shape_env.set_unbacked_var_to_val(s, real)
# pyrefly: ignore # unbound-name
pending.remove(s)
# When an unbacked SymInt is perfectly divisible by an integer
# constant, we replace it with the integer constant to improve
@ -1228,20 +1234,27 @@ def _free_unbacked_symbols_with_path(
source=shape_env.var_to_sources.get(s, [None])[0], # type: ignore[union-attr]
)
# pyrefly: ignore # unbound-name
unbacked = lhs if lhs in pending else rhs
divisor: IntLikeType = (
# pyrefly: ignore # unbound-name
int(coeff)
# pyrefly: ignore # unbound-name
if shape_env and isinstance(coeff, sympy.Integer)
# pyrefly: ignore # unbound-name
else _symint_wrap(coeff)
)
# TODO: DivideByKey needs to test divisibility at runtime!
# pyrefly: ignore # unsupported-operation
r[unbacked] = path + (DivideByKey(divisor),)
if real is not None:
assert isinstance(real, int)
val = (
# pyrefly: ignore # unbound-name
real // int(coeff)
# pyrefly: ignore # unbound-name
if isinstance(coeff, sympy.Integer)
# pyrefly: ignore # unbound-name
else CleanDiv(real, coeff)
)
if shape_env:
@ -1263,7 +1276,9 @@ def _free_unbacked_symbols_with_path(
if real is not None:
assert type(real) is bool
if shape_env:
# pyrefly: ignore # unbound-name
shape_env.set_unbacked_var_to_val(s, int(real))
# pyrefly: ignore # unbound-name
pending.remove(s.lhs)
return r
@ -1339,6 +1354,7 @@ def compute_unbacked_bindings(
):
if (
isinstance(old_sym, SymTypes)
# pyrefly: ignore # unbound-name
and (old_s := old_sym.node.expr) != new_s
):
# If old_s is not an unbacked_symbol,
@ -1348,11 +1364,15 @@ def compute_unbacked_bindings(
# and the original symbol gets replaced by the backed symbol.
# When this happens we just replace new_s by the old_s
# because we know the value is the same.
# pyrefly: ignore # unbound-name
if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s):
# pyrefly: ignore # unbound-name
shape_env._rename_unbacked_to(new_s, old_s)
else:
# pyrefly: ignore # unbound-name
shape_env._eliminate_unbacked(new_s, old_s)
elif not isinstance(old_sym, SymTypes):
# pyrefly: ignore # unbound-name
shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym))
return symbol_to_path
@ -3317,6 +3337,7 @@ class DimConstraints:
and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
): # derived dim with root = old_root
new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1
# pyrefly: ignore # unbound-name
new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1
c["eq"] = new_expr
@ -5313,7 +5334,7 @@ class ShapeEnv:
]
else:
assert len(input_contexts) == len(placeholders)
# pyrefly: ignore # bad-assignment
for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
if isinstance(t, Tensorlike):
if context is None:
@ -5663,13 +5684,12 @@ class ShapeEnv:
)
track_symint(property_source, ss, constraint_size[i])
else:
# pyrefly: ignore # missing-attribute
for i, ss in enumerate(curr_t.size()):
property_source = TensorPropertySource(
src, TensorProperty.SIZE, i
)
track_symint(property_source, ss, constraint_size[i])
# pyrefly: ignore # missing-attribute
for i, ss in enumerate(curr_t.stride()):
property_source = TensorPropertySource(
src, TensorProperty.STRIDE, i
@ -5677,7 +5697,6 @@ class ShapeEnv:
track_symint(property_source, ss, constraint_stride[i])
track_symint(
TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
# pyrefly: ignore # missing-attribute
curr_t.storage_offset(),
)
@ -5723,7 +5742,6 @@ class ShapeEnv:
continue
if is_dim(source):
# pyrefly: ignore # missing-attribute
self.dim_constraints.add_equality(source, expr)
for exprs, printer, lang in zip(all_exprs, printers, langs):
@ -5877,7 +5895,6 @@ class ShapeEnv:
continue
expr = self.simplify(ra.expr)
# pyrefly: ignore # missing-attribute
self.dim_constraints.add(expr)
# 3. Every symbol must be within its value range (this handles 0/1
@ -5894,7 +5911,6 @@ class ShapeEnv:
verbose_expr = ""
if r.lower not in (-sympy.oo, -int_oo):
if any(is_dim(source) for source in sources):
# pyrefly: ignore # missing-attribute
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
# Only print lower bound in simplified mode if it is not the
# default
@ -5903,7 +5919,6 @@ class ShapeEnv:
verbose_expr = f"{r.lower} <= {rf} # {vr_sloc.lower}"
if r.upper not in (sympy.oo, int_oo):
if any(is_dim(source) for source in sources):
# pyrefly: ignore # missing-attribute
self.dim_constraints.add(sympy.Le(symbol, r.upper))
# nontrivial upper bound is always interesting
bounds.append(sympy.Le(symbol, r.upper, evaluate=False))
@ -6152,7 +6167,6 @@ class ShapeEnv:
else:
bindings[-s] = -arg
# pyrefly: ignore # bad-assignment
for t, arg in zip(placeholders, args):
if t is None:
continue
@ -7588,8 +7602,10 @@ class ShapeEnv:
log.info(
"oblivious_size %s -> %s (passed counterfactual)",
orig_expr,
# pyrefly: ignore # unbound-name
correct_hint,
)
# pyrefly: ignore # unbound-name
concrete_val = correct_hint
# NB: do NOT transmute into runtime assert
ok = True
@ -7606,8 +7622,10 @@ class ShapeEnv:
).xreplace(self.var_to_val)
).free_symbols
):
# pyrefly: ignore # unbound-name
self._log_real_tensor_propagation(orig_expr, unsound_result)
transmute_into_runtime_assert = True
# pyrefly: ignore # unbound-name
concrete_val = unsound_result
ok = True
@ -8035,7 +8053,6 @@ def _suggest_fixes_for_data_dependent_error_non_strict(
if isinstance(leaf, torch.SymInt):
src_map[str(leaf.node.expr)].append(name)
elif isinstance(leaf, torch.Tensor):
# pyrefly: ignore # bad-assignment
for i, dim in enumerate(leaf.shape):
if isinstance(dim, torch.SymInt):
src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]")

View File

@ -407,6 +407,7 @@ class MethodDispatcher(Dispatcher):
Dispatcher
"""
# pyrefly: ignore # bad-override
__slots__ = ("obj", "cls")
@classmethod

View File

@ -120,7 +120,7 @@ def _torchscript_schema_to_signature_impl(
# which makes it hard to do type annotation
kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment]
# This renders all previous arguments to positional only
# pyrefly: ignore # bad-assignment
for idx, p in enumerate(parameters):
assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
parameters[idx] = Parameter(
@ -129,7 +129,7 @@ def _torchscript_schema_to_signature_impl(
default=p.default,
annotation=p.annotation,
)
# pyrefly: ignore # missing-attribute
parameters.append(
Parameter(name=name, kind=kind, default=default, annotation=arg_type)
)
@ -143,7 +143,6 @@ def _torchscript_schema_to_signature_impl(
else:
return_type = tuple(return_types)
# pyrefly: ignore # bad-argument-type
return inspect.Signature(parameters, return_annotation=return_type)

View File

@ -241,7 +241,6 @@ def tensorify_python_scalars(
# pyrefly: ignore # missing-attribute
val = node.meta.get("val")
if isinstance(val, FakeTensor):
# pyrefly: ignore # bad-assignment
for dim in val.shape:
if isinstance(dim, torch.SymInt):
for s in dim.node.expr.free_symbols:
@ -277,6 +276,7 @@ def tensorify_python_scalars(
):
transform = True
try:
# pyrefly: ignore # unbound-name
proxy = _sympy_interp(zf.node.expr)
except NotImplementedError:
transform = False
@ -303,6 +303,7 @@ def tensorify_python_scalars(
args.append(a)
if transform:
# pyrefly: ignore # unbound-name
replacement_proxy = replacement_op(*args)
# pyrefly: ignore # missing-attribute

View File

@ -93,6 +93,7 @@ class FakeTensorProp(torch.fx.Interpreter):
if (shape_env := self._mode.shape_env) and (
symbol_to_path := compute_unbacked_bindings(shape_env, result)
):
# pyrefly: ignore # unbound-name
n.meta["unbacked_bindings"] = symbol_to_path
return result

View File

@ -274,7 +274,6 @@ class PassManager:
logger.debug("Running pass '%s'", fn_name)
try:
# pyrefly: ignore # not-callable
res = fn(module)
if not isinstance(res, PassResult) and not hasattr(

View File

@ -395,21 +395,25 @@ class _MinimizerBase:
report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
if self.module_exporter:
if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
result_key = result_key[-1]
# If the result is still a tuple (happens in non-sequential mode),
# we only use the first element as name.
if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
result_key = str(result_key[0])
# pyre-ignore[29]: not a function
self.module_exporter(
a_input,
submodule,
# pyrefly: ignore # unbound-name
result_key + "_cpu",
)
# pyre-ignore[29]: not a function
self.module_exporter(
b_input,
submodule,
# pyrefly: ignore # unbound-name
result_key + "_acc",
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]

View File

@ -298,10 +298,14 @@ def insert_deferred_runtime_asserts(
and s not in expr_to_proxy
):
with _set_node_metadata_hook(gm, _node_metadata_hook):
# pyrefly: ignore # unbound-name
expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
# pyrefly: ignore # unbound-name
log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
# pyrefly: ignore # unbound-name
match_symbol(example_value, lambda: node)
# pyrefly: ignore # unbound-name
if isinstance(t := example_value, torch.Tensor):
for i, s in enumerate(t.size()):
match_symbol(
@ -382,6 +386,7 @@ def insert_deferred_runtime_asserts(
# maybe re-reify expression, replace current node
if (
# pyrefly: ignore # unbound-name
sym_expr in expr_to_proxy
or ( # example value is redundant
_is_intermediate_tensor_sym_call(node)
@ -400,20 +405,30 @@ def insert_deferred_runtime_asserts(
nn_module_stack=node.meta.get("nn_module_stack"),
),
):
# pyrefly: ignore # unbound-name
expr_to_proxy[sym_expr] = _sympy_interp(
expr_to_proxy, sym_expr
expr_to_proxy,
# pyrefly: ignore # unbound-name
sym_expr,
) # type: ignore[arg-type]
# won't try DCE-ing tensor compute here
hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type]
node.replace_all_uses_with(hash_node)
gm.graph.erase_node(node)
log.debug(
"CSE node %s -> %s for expr %s", node, hash_node, sym_expr
"CSE node %s -> %s for expr %s",
node,
hash_node,
# pyrefly: ignore # unbound-name
sym_expr,
)
# store node in hash cons, don't delete/replace
# pyrefly: ignore # unbound-name
elif sym_expr not in expr_to_proxy and not isinstance(
sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
# pyrefly: ignore # unbound-name
sym_expr,
(sympy.Number, sympy.logic.boolalg.BooleanAtom),
): # don't hash cons primitives
expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer) # type: ignore[arg-type]

View File

@ -317,6 +317,7 @@ def split_module(
and isinstance(s0 := val.node.expr, sympy.Symbol)
and s0 not in symbol_to_node
):
# pyrefly: ignore # unbound-name
symbol_to_node[val.node.expr] = node
if node.op in ["placeholder", "get_attr", "output"]:

View File

@ -84,6 +84,7 @@ def get_source_partitions(
if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and (
torch_fn := node.meta.get("torch_fn", None)
) is not None:
# pyrefly: ignore # unbound-name
node_fqn, source_fn = torch_fn
source_fn_name = source_fn.split(".")[1]
if source_fn_name in wanted_sources:

View File

@ -288,7 +288,7 @@ def _replace_pattern(
elif isinstance(pattern, Graph):
pattern_graph = pattern
else:
pattern_graph = symbolic_trace(pattern).graph
pattern_graph = symbolic_trace(pattern).graph # type: ignore[arg-type]
matcher = SubgraphMatcher(
pattern_graph,
@ -321,7 +321,7 @@ def _replace_pattern(
assert replacement_callback is not None, (
"Must provide either a replacement GraphModule or a replacement callback"
)
common_replacement_graph = None
common_replacement_graph = None # type: ignore[assignment]
# As we progressively replace nodes, we'll need to keep track of how the match results should change
match_changed_node: dict[Node, Node] = {}

View File

@ -561,7 +561,6 @@ def cat(tensors: list[list[int]], dim: int):
for i in range(len(tensors)):
tensor = tensors[i]
if not should_skip(tensor):
# pyrefly: ignore # bad-argument-type
check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
cat_dim_size = cat_dim_size + tensor[dim]

View File

@ -128,6 +128,7 @@ def _format_model_info(model_info: ModelInfo) -> str:
target_to_messages = {}
for node, message in model_info.dispatch_failures:
if str(node.target) not in target_to_messages:
# pyrefly: ignore # unsupported-operation
target_to_messages[str(node.target)] = message
for target, nodes in sorted(

View File

@ -149,6 +149,7 @@ class ElementwiseTypePromotionRule(TypePromotionRule):
f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})"
)
# pyrefly: ignore # bad-override
def __eq__(self, other: object, /) -> bool:
if not isinstance(other, ElementwiseTypePromotionRule):
return False
@ -265,6 +266,7 @@ class ReductionTypePromotionRule(TypePromotionRule):
def __repr__(self):
return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})"
# pyrefly: ignore # bad-override
def __eq__(self, other: object, /) -> bool:
if not isinstance(other, ElementwiseTypePromotionRule):
return False

View File

@ -298,9 +298,12 @@ def _create_node(
for key, value in sorted(attributes.items()):
if key in _SKIP_NODE_ATTRIBUTES:
continue
# pyrefly: ignore # unbound-name
_add_attribute(node, key, value, aten=aten)
if shape_inference:
# pyrefly: ignore # unbound-name
_C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
# pyrefly: ignore # unbound-name
return node

View File

@ -219,7 +219,6 @@ def index_put(
if len(indices_list) > 1:
for idx_ in range(len(indices_list)):
if symbolic_helper._is_bool(indices_list[idx_]):
# pyrefly: ignore # unsupported-operation
indices_list[idx_] = g.op("NonZero", indices_list[idx_])
index = indices_list[0]

View File

@ -698,7 +698,6 @@ def _multi_tensor_adam(
device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
)
# pyrefly: ignore # no-matching-overload
torch._foreach_mul_(device_exp_avg_sqs, beta2)
# Due to the strictness of the _foreach_addcmul API, we can't have a single

View File

@ -115,11 +115,14 @@ def _strong_wolfe(
t = _cubic_interpolate(
# pyrefly: ignore # index-error
bracket[0],
# pyrefly: ignore # unbound-name
bracket_f[0],
bracket_gtd[0], # type: ignore[possibly-undefined]
# pyrefly: ignore # index-error
bracket[1],
# pyrefly: ignore # unbound-name
bracket_f[1],
# pyrefly: ignore # unbound-name
bracket_gtd[1],
)
@ -130,14 +133,20 @@ def _strong_wolfe(
# + `t` is at one of the boundary,
# we will move `t` to a position which is `0.1 * len(bracket)`
# away from the nearest boundary point.
# pyrefly: ignore # unbound-name
eps = 0.1 * (max(bracket) - min(bracket))
# pyrefly: ignore # unbound-name
if min(max(bracket) - t, t - min(bracket)) < eps:
# interpolation close to boundary
# pyrefly: ignore # unbound-name
if insuf_progress or t >= max(bracket) or t <= min(bracket):
# evaluate at 0.1 away from boundary
# pyrefly: ignore # unbound-name
if abs(t - max(bracket)) < abs(t - min(bracket)):
# pyrefly: ignore # unbound-name
t = max(bracket) - eps
else:
# pyrefly: ignore # unbound-name
t = min(bracket) + eps
insuf_progress = False
else:
@ -151,13 +160,17 @@ def _strong_wolfe(
gtd_new = g_new.dot(d)
ls_iter += 1
# pyrefly: ignore # unbound-name
if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
# Armijo condition not satisfied or not lower than lowest point
# pyrefly: ignore # unsupported-operation
bracket[high_pos] = t
# pyrefly: ignore # unbound-name
bracket_f[high_pos] = f_new
bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
bracket_gtd[high_pos] = gtd_new
# pyrefly: ignore # unbound-name
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
else:
if abs(gtd_new) <= -c2 * gtd:
@ -168,19 +181,24 @@ def _strong_wolfe(
# old high becomes new low
# pyrefly: ignore # unsupported-operation
bracket[high_pos] = bracket[low_pos]
# pyrefly: ignore # unbound-name
bracket_f[high_pos] = bracket_f[low_pos]
bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
bracket_gtd[high_pos] = bracket_gtd[low_pos]
# new point becomes new low
# pyrefly: ignore # unsupported-operation
bracket[low_pos] = t
# pyrefly: ignore # unbound-name
bracket_f[low_pos] = f_new
bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
bracket_gtd[low_pos] = gtd_new
# return stuff
t = bracket[low_pos] # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
f_new = bracket_f[low_pos]
g_new = bracket_g[low_pos] # type: ignore[possibly-undefined]
return f_new, g_new, t, ls_func_evals

View File

@ -420,6 +420,7 @@ class LambdaLR(LRScheduler):
for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
# pyrefly: ignore # unsupported-operation
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
return state_dict
@ -539,6 +540,7 @@ class MultiplicativeLR(LRScheduler):
for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
# pyrefly: ignore # unsupported-operation
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
return state_dict
@ -1215,6 +1217,7 @@ class SequentialLR(LRScheduler):
state_dict["_schedulers"] = [None] * len(self._schedulers)
for idx, s in enumerate(self._schedulers):
# pyrefly: ignore # unsupported-operation
state_dict["_schedulers"][idx] = s.state_dict()
return state_dict
@ -1557,6 +1560,7 @@ class ChainedScheduler(LRScheduler):
state_dict["_schedulers"] = [None] * len(self._schedulers)
for idx, s in enumerate(self._schedulers):
# pyrefly: ignore # unsupported-operation
state_dict["_schedulers"][idx] = s.state_dict()
return state_dict

View File

@ -337,7 +337,6 @@ def _single_tensor_sgd(
if not torch.jit.is_scripting():
lr = _to_scalar(lr)
# pyrefly: ignore # bad-assignment
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
@ -433,12 +432,10 @@ def _multi_tensor_sgd(
all_states_with_momentum_buffer = True
for i in range(len(device_momentum_buffer_list)):
# pyrefly: ignore # index-error
if device_momentum_buffer_list[i] is None:
all_states_with_momentum_buffer = False
break
else:
# pyrefly: ignore # index-error
bufs.append(cast(Tensor, device_momentum_buffer_list[i]))
if all_states_with_momentum_buffer:
@ -446,15 +443,13 @@ def _multi_tensor_sgd(
torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
else:
bufs = []
# pyrefly: ignore # bad-assignment
for i in range(len(device_momentum_buffer_list)):
# pyrefly: ignore # index-error
if device_momentum_buffer_list[i] is None:
buf = device_momentum_buffer_list[i] = momentum_buffer_list[
indices[i]
] = device_grads[i].detach().clone()
else:
# pyrefly: ignore # index-error
buf = cast(Tensor, device_momentum_buffer_list[i])
buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

View File

@ -672,7 +672,7 @@ class MemoryProfile:
output: list[tuple[int, Action, KeyAndID, int]] = []
allocation_times: dict[tuple[TensorKey, bool], int] = {}
live_unknown: dict[tuple[int, torch.device], Literal[True]] = {}
# pyrefly: ignore # bad-assignment
for event in self._op_tree.dfs():
if event.typed[0] == _EventType.Allocation:
alloc_fields = event.typed[1]
@ -774,14 +774,12 @@ class MemoryProfile:
for key, (_, version) in node.inputs.items()
if self._categories.get(key, version)
in (Category.GRADIENT, Category.PARAMETER)
# pyrefly: ignore # unsupported-operation
or key.id in depends_on_gradient
)
if ids:
# pyrefly: ignore # missing-attribute
depends_on_gradient.update(ids)
# pyrefly: ignore # missing-attribute
depends_on_gradient.update(key.id for key in node.outputs)
# We are guaranteed to exit because there is a finite set of
@ -790,7 +788,6 @@ class MemoryProfile:
# once to fold the first step into that loop, and a third time
# where no new elements are added.
if len(depends_on_gradient) == start_size:
# pyrefly: ignore # bad-return
return depends_on_gradient
def _set_gradients_and_temporaries(self) -> None:

View File

@ -140,6 +140,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
@ -172,6 +173,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device
)
# pyrefly: ignore # unbound-name
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
return (sparse, meta_reordered.view(m, meta_ncols))

View File

@ -385,7 +385,7 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
g1 = c_offsets[r + 1]
for g in range(g0, g1):
p, q = pq[g]
# pyrefly: ignore # unsupported-operation
accumulators[r] += blocks[p] @ others[q]
else:
_scatter_mm2(blocks, others, c_offsets, pq, accumulators)

View File

@ -1219,6 +1219,7 @@ def originate_pairs(
else:
for pair_type in pair_types:
try:
# pyrefly: ignore # bad-instantiation
return [pair_type(actual, expected, id=id, **options)]
# Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
# inputs. Thus, we try the next pair type.

View File

@ -95,7 +95,7 @@ from torch.utils._import_utils import _check_module_exists
import torch.utils._pytree as pytree
from torch.utils import cpp_extension
try:
import pytest
import pytest # type: ignore[import-not-found]
has_pytest = True
except ImportError:
has_pytest = False

View File

@ -117,6 +117,7 @@ def context_decorator(ctx, func):
@functools.wraps(func)
def decorate_context(*args, **kwargs):
# pyrefly: ignore # bad-context-manager
with ctx_factory():
return func(*args, **kwargs)

View File

@ -41,7 +41,10 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
)
# pyrefly: ignore # import-error
import optree
# pyrefly: ignore # import-error
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
@ -706,6 +709,7 @@ def tree_map_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -766,6 +770,7 @@ def tree_map_only_(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -1079,6 +1084,7 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
with python_pytree._NODE_REGISTRY_LOCK:
# pyrefly: ignore # bad-assignment
python_pytree._cxx_pytree_imported = True
args, kwargs = (), {} # type: ignore[var-annotated]
for args, kwargs in python_pytree._cxx_pytree_pending_imports:

View File

@ -152,6 +152,7 @@ class DebugMode(TorchDispatchMode):
super().__enter__()
return self
# pyrefly: ignore # bad-override
def __exit__(self, *args):
super().__exit__(*args)
if self.record_torchfunction:

View File

@ -60,6 +60,7 @@ def _device_constructors():
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
def __init__(self, device):
# pyrefly: ignore # read-only
self.device = torch.device(device)
def __enter__(self):

View File

@ -35,10 +35,12 @@ def cache_method(
if not (cache := getattr(self, cache_name, None)):
cache = {}
setattr(self, cache_name, cache)
# pyrefly: ignore # unbound-name
cached_value = cache.get(args, _cache_sentinel)
if cached_value is not _cache_sentinel:
return cached_value
value = f(self, *args, **kwargs)
# pyrefly: ignore # unbound-name
cache[args] = value
return value

View File

@ -158,6 +158,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
# MutableSet impl will iterate over other, iter over smaller of two sets
if isinstance(other, OrderedSet) and len(self) < len(other):
# pyrefly: ignore # unsupported-operation, bad-return
return other & self
return cast(OrderedSet[T], super().__and__(other))

View File

@ -708,6 +708,7 @@ class structseq(tuple[_T_co, ...]):
def __new__(
cls: type[Self],
sequence: Iterable[_T_co],
# pyrefly: ignore # bad-function-definition
dict: dict[str, Any] = ...,
) -> Self:
raise NotImplementedError
@ -754,6 +755,7 @@ def _tuple_flatten_with_keys(
d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _tuple_flatten(d)
# pyrefly: ignore # bad-return
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -767,6 +769,7 @@ def _list_flatten(d: list[T]) -> tuple[list[T], Context]:
def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _list_flatten(d)
# pyrefly: ignore # bad-return
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -782,6 +785,7 @@ def _dict_flatten_with_keys(
d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _dict_flatten(d)
# pyrefly: ignore # bad-return
return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -797,6 +801,7 @@ def _namedtuple_flatten_with_keys(
d: NamedTuple,
) -> tuple[list[tuple[KeyEntry, Any]], Context]:
values, context = _namedtuple_flatten(d)
# pyrefly: ignore # bad-return
return (
[(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
context,
@ -846,6 +851,7 @@ def _ordereddict_flatten_with_keys(
d: OrderedDict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _ordereddict_flatten(d)
# pyrefly: ignore # bad-return
return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -870,6 +876,7 @@ def _defaultdict_flatten_with_keys(
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _defaultdict_flatten(d)
_, dict_context = context
# pyrefly: ignore # bad-return
return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context
@ -918,6 +925,7 @@ def _deque_flatten_with_keys(
d: deque[T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _deque_flatten(d)
# pyrefly: ignore # bad-return
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -1547,6 +1555,7 @@ def tree_map_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -1607,6 +1616,7 @@ def tree_map_only_(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -1819,6 +1829,7 @@ def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
for attr in classname.split("."):
enum_cls = getattr(enum_cls, attr)
enum_cls = cast(type[Enum], enum_cls)
# pyrefly: ignore # unsupported-operation
return enum_cls[obj["name"]]
return obj

View File

@ -305,6 +305,7 @@ def strobelight(
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
# pyrefly: ignore # bad-argument-type
return profiler.profile(work_function, *args, **kwargs)
return wrapper_function

View File

@ -105,6 +105,7 @@ def _keep_float(
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
@functools.wraps(f)
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
# pyrefly: ignore # bad-argument-type
r: Union[_T, sympy.Float] = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float
@ -112,6 +113,7 @@ def _keep_float(
r = sympy.Float(float(r))
return r
# pyrefly: ignore # bad-return
return inner
@ -198,10 +200,12 @@ class FloorDiv(sympy.Function):
@property
def base(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
return self.args[0]
@property
def divisor(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
return self.args[1]
def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
@ -370,6 +374,7 @@ class ModularIndexing(sympy.Function):
return None
def _eval_is_nonnegative(self) -> Optional[bool]:
# pyrefly: ignore # missing-attribute
p, q = self.args[:2]
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]
@ -450,6 +455,7 @@ class PythonMod(sympy.Function):
# - floor(p / q) = 0
# - p % q = p - floor(p / q) * q = p
less = p < q
# pyrefly: ignore # missing-attribute
if less.is_Boolean and bool(less) and r.is_positive:
return p
@ -466,8 +472,11 @@ class PythonMod(sympy.Function):
return True if self.args[1].is_negative else None # type: ignore[attr-defined]
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
abs_q = str(q) if self.args[1].is_positive else f"abs({q})"
return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}"
@ -548,6 +557,7 @@ class CeilToInt(sympy.Function):
return sympy.Integer(math.ceil(float(number)))
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5)
return f"ceil({number})"
@ -818,6 +828,7 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
if not cond:
return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
if isinstance(ai, cls):
# pyrefly: ignore # missing-attribute
return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
return a
@ -995,6 +1006,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
return fuzzy_and(a.is_negative for a in self.args)
@ -1013,6 +1025,7 @@ class Min(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
return fuzzy_or(a.is_negative for a in self.args)
@ -1150,7 +1163,9 @@ class IntTrueDiv(sympy.Function):
return sympy.Float(int(base) / int(divisor))
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
return f"((int){base}/(int){divisor})"
@ -1310,9 +1325,11 @@ class Identity(sympy.Function):
precedence = 10
def __repr__(self): # type: ignore[override]
# pyrefly: ignore # missing-attribute
return f"Identity({self.args[0]})"
def _eval_is_real(self):
# pyrefly: ignore # missing-attribute
return self.args[0].is_real
def _eval_is_integer(self):
@ -1320,12 +1337,15 @@ class Identity(sympy.Function):
def _eval_expand_identity(self, **hints):
# Removes the identity op.
# pyrefly: ignore # missing-attribute
return self.args[0]
def __int__(self) -> int:
# pyrefly: ignore # missing-attribute
return int(self.args[0])
def __float__(self) -> float:
# pyrefly: ignore # missing-attribute
return float(self.args[0])

View File

@ -9,6 +9,7 @@ from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton
# pyrefly: ignore # invalid-inheritance
class IntInfinity(Number, metaclass=Singleton):
r"""Positive integer infinite quantity.
@ -203,6 +204,7 @@ class IntInfinity(Number, metaclass=Singleton):
int_oo = S.IntInfinity
# pyrefly: ignore # invalid-inheritance
class NegativeIntInfinity(Number, metaclass=Singleton):
"""Negative integer infinite quantity.

View File

@ -66,6 +66,7 @@ class ExprPrinter(StrPrinter):
# NB: this pow by natural, you should never have used builtin sympy.pow
# for FloatPow, and a symbolic exponent should be PowByNatural. These
# means exp is guaranteed to be integer.
# pyrefly: ignore # bad-override
def _print_Pow(self, expr: sympy.Expr) -> str:
base, exp = expr.args
assert exp == int(exp), exp

View File

@ -175,6 +175,7 @@ class ReferenceAnalysis:
@staticmethod
def pow(a, b):
# pyrefly: ignore # bad-argument-type
return _keep_float(FloatPow)(a, b)
@staticmethod

View File

@ -123,7 +123,9 @@ AllFn2 = Union[ExprFn2, BoolFn2]
class ValueRanges(Generic[_T]):
if TYPE_CHECKING:
# ruff doesn't understand circular references but mypy does
# pyrefly: ignore # unbound-name
ExprVR = ValueRanges[sympy.Expr] # noqa: F821
# pyrefly: ignore # unbound-name
BoolVR = ValueRanges[SympyBoolean] # noqa: F821
AllVR = Union[ExprVR, BoolVR]
@ -464,6 +466,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def to_dtype(a, dtype, src_dtype=None):
if dtype == torch.float64:
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(a, ToFloat)
elif dtype == torch.bool:
return ValueRanges.unknown_bool()
@ -473,6 +476,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def trunc_to_int(a, dtype):
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(a, TruncToInt)
@staticmethod
@ -621,7 +625,10 @@ class SymPyValueRangeAnalysis:
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(
a, b, _keep_float(IntTrueDiv)
a,
b,
# pyrefly: ignore # bad-argument-type
_keep_float(IntTrueDiv),
)
@staticmethod
@ -634,7 +641,10 @@ class SymPyValueRangeAnalysis:
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(
a, b, _keep_float(FloatTrueDiv)
a,
b,
# pyrefly: ignore # bad-argument-type
_keep_float(FloatTrueDiv),
)
@staticmethod
@ -713,6 +723,7 @@ class SymPyValueRangeAnalysis:
# We should know that b >= 0 but we may have forgotten this fact due
# to replacements, so don't assert it, but DO clamp it to prevent
# degenerate problems
# pyrefly: ignore # no-matching-overload
return ValueRanges.coordinatewise_increasing_map(
a, b & ValueRanges(0, int_oo), PowByNatural
)
@ -879,6 +890,7 @@ class SymPyValueRangeAnalysis:
@classmethod
def round_to_int(cls, number, dtype):
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(number, RoundToInt)
# It's used in some models on symints
@ -992,6 +1004,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def trunc(x):
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(x, TruncToFloat)

View File

@ -202,6 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
Args:
device (int, optional): if specified, all parameters will be copied to that device
"""
# pyrefly: ignore # missing-attribute
return self._apply(lambda t: getattr(t, custom_backend_name)(device))
_check_register_once(torch.nn.Module, custom_backend_name)

View File

@ -63,6 +63,7 @@ def generate_coo_data(size, sparse_dim, nnz, dtype, device):
indices = torch.rand(sparse_dim, nnz, device=device)
indices.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(indices))
indices = indices.to(torch.long)
# pyrefly: ignore # no-matching-overload
values = torch.rand([nnz, ], dtype=dtype, device=device)
return indices, values

View File

@ -15,6 +15,7 @@ _warned_tensor_cores = False
_default_float_32_precision = torch.get_float32_matmul_precision()
try:
from tabulate import tabulate
HAS_TABULATE = True
@ -169,6 +170,7 @@ if HAS_TABULATE:
_disable_tensor_cores()
table.append([
("Training" if optimizer else "Inference"),
# pyrefly: ignore # redundant-condition
backend if backend else "-",
mode if mode is not None else "-",
f"{compilation_time} ms " if compilation_time else "-",
@ -189,4 +191,5 @@ if HAS_TABULATE:
])
# pyrefly: ignore # not-callable
return tabulate(table, headers=field_names, tablefmt="github")

View File

@ -35,6 +35,7 @@ def _get_build_root() -> str:
global _BUILD_ROOT
if _BUILD_ROOT is None:
_BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
# pyrefly: ignore # missing-argument
atexit.register(shutil.rmtree, _BUILD_ROOT)
return _BUILD_ROOT

View File

@ -91,6 +91,7 @@ class FuzzedSparseTensor(FuzzedTensor):
return x
def _make_tensor(self, params, state):
# pyrefly: ignore # missing-attribute
size, _, _ = self._get_size_and_steps(params)
density = params['density']
nnz = math.ceil(sum(size) * density)
@ -99,8 +100,10 @@ class FuzzedSparseTensor(FuzzedTensor):
is_coalesced = params['coalesced']
sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size)
sparse_dim = min(sparse_dim, len(size))
# pyrefly: ignore # missing-attribute
tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced)
# pyrefly: ignore # missing-attribute
if self._cuda:
tensor = tensor.cuda()
sparse_dim = tensor.sparse_dim()
@ -116,6 +119,7 @@ class FuzzedSparseTensor(FuzzedTensor):
"sparse_dim": sparse_dim,
"dense_dim": dense_dim,
"is_hybrid": is_hybrid,
# pyrefly: ignore # missing-attribute
"dtype": str(self._dtype),
}
return tensor, properties

View File

@ -233,6 +233,7 @@ class Timer:
setup = textwrap.dedent(setup)
setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()
# pyrefly: ignore # bad-instantiation
self._timer = self._timer_cls(
stmt=stmt,
setup=setup,

View File

@ -448,11 +448,13 @@ class GlobalsBridge:
load_lines = []
for name, wrapped_value in self._globals.items():
if wrapped_value.setup is not None:
# pyrefly: ignore # bad-argument-type
load_lines.append(textwrap.dedent(wrapped_value.setup))
if wrapped_value.serialization == Serialization.PICKLE:
path = os.path.join(self._data_dir, f"{name}.pkl")
load_lines.append(
# pyrefly: ignore # bad-argument-type
f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)")
with open(path, "wb") as f:
pickle.dump(wrapped_value.value, f)
@ -462,11 +464,13 @@ class GlobalsBridge:
# TODO: Figure out if we can use torch.serialization.add_safe_globals here
# Using weights_only=False after the change in
# https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
# pyrefly: ignore # bad-argument-type
load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)")
torch.save(wrapped_value.value, path)
elif wrapped_value.serialization == Serialization.TORCH_JIT:
path = os.path.join(self._data_dir, f"{name}.pt")
# pyrefly: ignore # bad-argument-type
load_lines.append(f"{name} = torch.jit.load({repr(path)})")
with open(path, "wb") as f:
torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call]

View File

@ -222,6 +222,7 @@ def _get_autocast_kwargs(device_type="cuda"):
class CheckpointFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
@ -784,6 +785,7 @@ class _Holder:
class _NoopSaveInputs(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(*args):
return torch.empty((0,))
@ -1006,6 +1008,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
def logging_mode():
with LoggingTensorMode(), \
capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
# pyrefly: ignore # bad-assignment
self.logs, self.tbs = logs_and_tb
yield logs_and_tb
return logging_mode()

View File

@ -787,6 +787,7 @@ class BuildExtension(build_ext):
# Use absolute path for output_dir so that the object file paths
# (`objects`) get generated with absolute paths.
# pyrefly: ignore # no-matching-overload
output_dir = os.path.abspath(output_dir)
# See Note [Absolute include_dirs]
@ -977,6 +978,7 @@ class BuildExtension(build_ext):
is_standalone=False):
if not self.compiler.initialized:
self.compiler.initialize()
# pyrefly: ignore # no-matching-overload
output_dir = os.path.abspath(output_dir)
# Note [Absolute include_dirs]
@ -1528,6 +1530,7 @@ def include_paths(device_type: str = "cpu", torch_include_dirs=True) -> list[str
# Support CUDA_INC_PATH env variable supported by CMake files
if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
cuda_inc_path != '/usr/include':
# pyrefly: ignore # unbound-name
paths.append(cuda_inc_path)
if CUDNN_HOME is not None:
paths.append(os.path.join(CUDNN_HOME, 'include'))
@ -2569,6 +2572,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]:
def _get_vc_env(vc_arch: str) -> dict[str, str]:
try:
from setuptools import distutils # type: ignore[attr-defined]
# pyrefly: ignore # missing-attribute
return distutils._msvccompiler._get_vc_env(vc_arch)
except AttributeError:
try:

View File

@ -204,6 +204,7 @@ def collate(
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
# pyrefly: ignore # not-iterable
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError("each element in list of batch should be of equal size")
transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.

View File

@ -70,6 +70,7 @@ def pin_memory(data, device=None):
return clone
else:
return type(data)(
# pyrefly: ignore # bad-argument-count
{k: pin_memory(sample, device) for k, sample in data.items()}
) # type: ignore[call-arg]
except TypeError:

View File

@ -674,6 +674,7 @@ class _BaseDataLoaderIter:
# Set pin memory device based on the current accelerator.
self._pin_memory_device = (
# pyrefly: ignore # unbound-name
acc.type
if self._pin_memory
and (acc := torch.accelerator.current_accelerator()) is not None

View File

@ -265,6 +265,7 @@ class _DataPipeType:
# Default type for DataPipe without annotation
_T_co = TypeVar("_T_co", covariant=True)
# pyrefly: ignore # invalid-annotation
_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])
@ -283,6 +284,7 @@ class _DataPipeMeta(GenericMeta):
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
# TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
# pyrefly: ignore # no-access
cls.__origin__ = None
if "type" in namespace:
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]

View File

@ -80,6 +80,7 @@ class Capture:
def _ops_str(self):
res = ""
# pyrefly: ignore # not-iterable
for op in self.ctx["operations"]:
if len(res) > 0:
res += "\n"
@ -89,6 +90,7 @@ class Capture:
def __getstate__(self):
# TODO(VitalyFedyunin): Currently can't pickle (why?)
self.ctx["schema_df"] = None
# pyrefly: ignore # not-iterable
for var in self.ctx["variables"]:
var.calculated_value = None
state = {}
@ -112,11 +114,13 @@ class Capture:
return CaptureGetItem(self, key, ctx=self.ctx)
def __setitem__(self, key, value):
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))
def __add__(self, add_val):
res = CaptureAdd(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)
@ -125,6 +129,7 @@ class Capture:
def __sub__(self, add_val):
res = CaptureSub(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)
@ -134,15 +139,19 @@ class Capture:
res = CaptureMul(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(t)
return var
def _is_context_empty(self):
# pyrefly: ignore # bad-argument-type
return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0
def apply_ops_2(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
for op in self.ctx["operations"]:
op.execute()
@ -175,6 +184,7 @@ class Capture:
res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
var = CaptureVariable(None, ctx=self.ctx)
t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(t)
return var
@ -273,7 +283,9 @@ class CaptureVariable(Capture):
def apply_ops(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
for op in self.ctx["operations"]:
op.execute()
return self.calculated_value
@ -373,6 +385,7 @@ def get_val(capture):
class CaptureInitial(CaptureVariable):
def __init__(self, schema_df=None):
# pyrefly: ignore # bad-assignment
new_ctx: dict[str, list[Any]] = {
"operations": [],
"variables": [],
@ -388,6 +401,7 @@ class CaptureDataFrame(CaptureInitial):
class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
def as_datapipe(self):
# pyrefly: ignore # unsupported-operation
return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)
def raw_iterator(self):

View File

@ -92,6 +92,7 @@ class FilterDataFramesPipe(DFIterDataPipe):
size = None
all_buffer = []
filter_res = []
# pyrefly: ignore # bad-assignment
for df in self.source_datapipe:
if size is None:
size = len(df.index)

View File

@ -135,6 +135,7 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
_fast_forward_iterator: Optional[Iterator] = None
def __iter__(self) -> Iterator[_T_co]:
# pyrefly: ignore # bad-return
return self
def __getattr__(self, attribute_name):
@ -379,6 +380,7 @@ class _DataPipeSerializationWrapper:
value = pickle.dumps(self._datapipe)
except Exception:
if HAS_DILL:
# pyrefly: ignore # missing-attribute
value = dill.dumps(self._datapipe)
use_dill = True
else:
@ -388,6 +390,7 @@ class _DataPipeSerializationWrapper:
def __setstate__(self, state):
value, use_dill = state
if use_dill:
# pyrefly: ignore # missing-attribute
self._datapipe = dill.loads(value)
else:
self._datapipe = pickle.loads(value)
@ -404,6 +407,7 @@ class _DataPipeSerializationWrapper:
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
def __init__(self, datapipe: IterDataPipe[_T_co]):
super().__init__(datapipe)
# pyrefly: ignore # invalid-type-var
self._datapipe_iter: Optional[Iterator[_T_co]] = None
def __iter__(self) -> "_IterDataPipeSerializationWrapper":

View File

@ -118,6 +118,7 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
for idx in sorted(self.input_col[1:], reverse=True):
del data[idx]
else:
# pyrefly: ignore # unsupported-operation
data[self.input_col] = res
else:
if self.output_col == -1:

View File

@ -42,6 +42,7 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
"Sampler class requires input datapipe implemented `__len__`"
)
super().__init__()
# pyrefly: ignore # bad-assignment
self.datapipe = datapipe
self.sampler_args = () if sampler_args is None else sampler_args
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs

Some files were not shown because too many files have changed in this diff Show More