Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Pyrefly suppressions 6/n (#164877)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283. Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

Step 1: delete lines in the pyrefly.toml file from the project-excludes field
Step 2: run pyrefly check
Step 3: add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199
After: INFO 0 errors (5,064 ignored)

Only four directories are left to enable. A minimal sketch of what these suppressions look like in practice follows the commit metadata below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164877
Approved by: https://github.com/oulgen
Committed by: PyTorch MergeBot
Parent: ad7b2bebc6
Commit: 086dec3235
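For readers unfamiliar with the workflow described above, the suppressions this PR adds are line-level `# pyrefly: ignore # <error-code>` comments placed immediately above the statement that pyrefly flags, mirroring how `# type: ignore[...]` works for mypy. The sketch below is illustrative only; the function and attribute names are hypothetical and not taken from this diff.

```python
import torch


def report_size(model: torch.nn.Module) -> None:
    # Hypothetical example: pyrefly cannot prove that `loaded_extra_state`
    # exists on this module, so the specific error code is suppressed on
    # the line directly above the offending statement.
    # pyrefly: ignore # missing-attribute
    print(model.loaded_extra_state)
```

Step 1 of the test plan then removes the file's glob from `project-excludes` in pyrefly.toml so the checker actually visits it, and step 3 adds (or prunes) suppressions like the one above until `pyrefly check` reports zero errors.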
@@ -24,12 +24,12 @@ project-excludes = [
"torch/distributed/**",
"torch/nn/**",
"torch/_dynamo/**",
"torch/utils/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",
"torch/jit/annotations.py",
"torch/utils/data/datapipes/_typing.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",

@@ -59,6 +59,7 @@ class TestBundledInputs(TestCase):
# despite having nominally large bundled inputs.
augmented_size = model_size(sm)

# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 12))

loaded = save_and_load(sm)
@@ -66,12 +67,15 @@ class TestBundledInputs(TestCase):
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
self.assertEqual(len(inflated), len(samples))

# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])

for idx, inp in enumerate(inflated):
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp, tuple)
self.assertEqual(len(inp), 1)

# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp[0], torch.Tensor)
if idx != 5:
# Strides might be important for benchmarking.
@@ -140,6 +144,7 @@ class TestBundledInputs(TestCase):
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(inflated, samples)

# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) == "first 1")

def test_multiple_methods_with_inputs(self):
@@ -187,6 +192,7 @@ class TestBundledInputs(TestCase):

# Check running and size helpers

# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

@@ -420,6 +426,7 @@ class TestBundledInputs(TestCase):
augmented_size = model_size(sm)
# assert the size has not increased more than 8KB

# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 13))

loaded = save_and_load(sm)

@@ -48,6 +48,7 @@ class TestComplexTensor(TestCase):
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
# pyrefly: ignore # missing-attribute
self.assertTrue(torch.all(x))

@dtypes(*complex_types())
@@ -56,6 +57,7 @@ class TestComplexTensor(TestCase):
x = torch.tensor(
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
)
# pyrefly: ignore # missing-attribute
self.assertFalse(torch.any(x))

@onlyCPU

@@ -142,6 +142,7 @@ class TestTypeHints(TestCase):
]
)
if result != 0:
# pyrefly: ignore # missing-attribute
self.fail(f"mypy failed:\n{stderr}\n{stdout}")

@@ -125,6 +125,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)

self.assertEqual(torch.float64.to_complex(), torch.complex128)
@@ -135,6 +136,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)

self.assertEqual(torch.complex128.to_real(), torch.double)

@@ -240,7 +240,7 @@ def get_decompositions(

registry = global_decomposition_table[type]
packets_to_overloads = defaultdict(list)
# pyrefly: ignore # bad-assignment

for opo in registry:
if isinstance(opo, (OpOverload, OpOverloadPacket)):
packets_to_overloads[opo.overloadpacket].append(opo)

@@ -4079,6 +4079,7 @@ def _nll_loss_forward(
return result, total_weight

if weight is not None:
# pyrefly: ignore # unbound-name
w = w.expand(self.shape)
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
wsum = torch.where(target != ignore_index, wsum, 0)
@@ -341,8 +341,8 @@ class MtiaInterface(DeviceInterface):
synchronize = staticmethod(torch.mtia.synchronize)
get_device_properties = staticmethod(torch.mtia.get_device_properties)  # type: ignore[assignment]
get_raw_stream = staticmethod(get_mtia_stream)  # type: ignore[assignment, arg-type]
exchange_device = staticmethod(torch.mtia._exchange_device)  # type: ignore[arg-type]
maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device)  # type: ignore[arg-type]
exchange_device = staticmethod(torch.mtia._exchange_device)  # type: ignore[arg-type, has-type]
maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device)  # type: ignore[arg-type, has-type]
memory_allocated = staticmethod(torch.mtia.memory_allocated)  # type: ignore[assignment]
is_bf16_supported = staticmethod(torch.mtia.is_bf16_supported)  # type: ignore[arg-type]

@@ -414,7 +414,7 @@ class XpuInterface(DeviceInterface):

current_device = staticmethod(torch.xpu.current_device)
set_device = staticmethod(torch.xpu.set_device)
device_count = staticmethod(torch.xpu.device_count)
device_count = staticmethod(torch.xpu.device_count)  # type: ignore[has-type]
stream = staticmethod(torch.xpu.stream)  # type: ignore[assignment]
current_stream = staticmethod(torch.xpu.current_stream)
set_stream = staticmethod(torch.xpu.set_stream)  # type: ignore[assignment]
@@ -422,8 +422,8 @@ class XpuInterface(DeviceInterface):
synchronize = staticmethod(torch.xpu.synchronize)
get_device_properties = staticmethod(torch.xpu.get_device_properties)  # type: ignore[assignment]
get_raw_stream = staticmethod(get_xpu_stream)  # type: ignore[assignment, arg-type]
exchange_device = staticmethod(torch.xpu._exchange_device)  # type: ignore[arg-type]
maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device)  # type: ignore[arg-type]
exchange_device = staticmethod(torch.xpu._exchange_device)  # type: ignore[arg-type, has-type]
maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device)  # type: ignore[arg-type, has-type]
memory_allocated = staticmethod(torch.xpu.memory_allocated)

# Can be mock patched by @patch decorator.
@@ -1097,7 +1097,6 @@ class TS2FXGraphConverter:

# Update the value of loop local variables.
if node.outputsSize() >= 1:
# pyrefly: ignore # bad-assignment
for i, outp in enumerate(node.outputs()):
output_name = outp.debugName()
self.name_to_node[output_name] = self.fx_graph.call_function(
@@ -1110,7 +1109,7 @@ class TS2FXGraphConverter:
fx_block_args[i] = self.name_to_node[output_name]

# Update the value of global variables, whose values are modified inplace.
# pyrefly: ignore # bad-assignment

for i, name in enumerate(
subgraph_converter.name_update_from_subblock_to_parent
):

@@ -140,7 +140,7 @@ def key_path_to_source(
source: Source = LocalSource("args")
else:
source, kp = sourced_prefixes.get(kp)
# pyrefly: ignore # bad-assignment

for k in kp:
if isinstance(k, SequenceKey):
source = GetItemSource(source, k.idx)

@@ -317,6 +317,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
)
res_proxy.node.meta.update(meta.data)
if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
# pyrefly: ignore # unbound-name
if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
self.tracer.set_metadata(res_proxy.node, res_data)

@@ -83,7 +83,6 @@ def _node_metadata_hook(
node.meta["torch_fn"] = node.meta.get(
"torch_fn",
(
# pyrefly: ignore # missing-attribute
f"{node.target.__name__}_0",
# pyrefly: ignore # missing-attribute
f"{node.target.__class__.__name__}.{node.target.__name__}",

@@ -646,6 +646,7 @@ def update_schema():
assert thrift_content[1].startswith("// checksum<<")
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))

# pyrefly: ignore # import-error
from yaml import load, Loader

dst = load(content, Loader=Loader)
@@ -2183,6 +2183,7 @@ class GraphModuleDeserializer(metaclass=Final):
simplify=True,
)
):
# pyrefly: ignore # unbound-name
node.meta["unbacked_bindings"] = unbacked_bindings

assert len(self.unbacked_symbols) == 0

@@ -471,7 +471,6 @@ def _check_input_constraints_for_graph(
elif isinstance(node_val, torch.SymInt):
_check_symint(
node_val,
# pyrefly: ignore # bad-argument-type
arg,
range_constraints,
unification_map,

@@ -204,6 +204,7 @@ def run_functionalized_fw_and_collect_metadata(
suppress_pending = contextlib.nullcontext()
fake_mode = detect_fake_mode()
if fake_mode and (shape_env := fake_mode.shape_env):
# pyrefly: ignore # unbound-name
suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
with disable_above, mode, suppress_pending:
# precondition: The passed in function already handles unflattening inputs + flattening outputs

@@ -439,7 +439,7 @@ def collect_fw_donated_buffer_idxs(
"""

storage_refs = set()
# pyrefly: ignore # bad-assignment

for t in itertools.chain(fw_ins, user_fw_outs, bw_outs):
# Only access storage if a tensor has storage (not sparse)
if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t):

@@ -196,6 +196,7 @@ class MemoryFormatMeta:

if use_memory_format:
return MemoryFormatMeta(
# pyrefly: ignore # unbound-name
memory_format=torch._prims_common.suggest_memory_format(t),
)

@@ -892,12 +893,15 @@ class GraphSignature:
parameters_to_mutate = {}
for output_name, mutation_name in outputs_to_mutations.items():
if mutation_name in user_inputs:
# pyrefly: ignore # unsupported-operation
user_inputs_to_mutate[output_name] = mutation_name
else:
assert mutation_name in buffers or mutation_name in parameters
if mutation_name in buffers:
# pyrefly: ignore # unsupported-operation
buffers_to_mutate[output_name] = mutation_name
else:
# pyrefly: ignore # unsupported-operation
parameters_to_mutate[output_name] = mutation_name

start, stop = stop, stop + num_user_outputs

@@ -232,7 +232,6 @@ def unwrap_tensor_subclasses(

attrs, _ = t.__tensor_flatten__()

# pyrefly: ignore # bad-assignment
for attr in attrs:
inner_tensor = getattr(t, attr)
n_desc: Any = (
@@ -314,7 +313,6 @@ def runtime_unwrap_tensor_subclasses(

for idx, x in enumerate(wrapped_args):
if not is_traceable_wrapper_subclass(x):
# pyrefly: ignore # bad-argument-type
xs_inner.append(x)
continue

@@ -199,6 +199,7 @@ def _extract_graph_with_inputs_outputs(
new_node = new_graph.placeholder(node.name)
# Can't use node_copy here as we may be turning previous call_function into placeholders
new_node.meta = node.meta
# pyrefly: ignore # unsupported-operation
env[node] = new_node

for node in joint_graph.nodes:
@@ -227,8 +228,10 @@ def _extract_graph_with_inputs_outputs(
if any(all_args):
env[node] = InvalidNode  # type: ignore[assignment]
continue
# pyrefly: ignore # unsupported-operation, bad-argument-type
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "get_attr":
# pyrefly: ignore # unsupported-operation, bad-argument-type
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "output":
pass
@@ -1403,12 +1406,14 @@ def functionalize_rng_ops(
devices = OrderedSet(
get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values()
)
# pyrefly: ignore # unbound-name
devices.discard(torch.device("cpu"))
# multiple cuda devices won't work with cudagraphs anyway,
# fallback to non graphsafe rng checkpointing
multi_cuda_devices = len(devices) > 1

# this changes numerics, so if fallback_random is set we will not use it
# pyrefly: ignore # unbound-name
ind_config = torch._inductor.config
use_rng_graphsafe_rng_functionalization = (
config.graphsafe_rng_functionalization
@@ -2840,6 +2845,7 @@ def min_cut_rematerialization_partition(
node_info,
memory_budget=memory_budget,
)
# pyrefly: ignore # unbound-name
if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_graph, saved_values)
# save_for_backward on tensors and stashes symints in autograd .ctx

@@ -57,6 +57,7 @@ def _interleave(a, b, dim=0):

stacked = torch.stack([a, b], dim=dim + 1)
interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1)
# pyrefly: ignore # unbound-name
if b_trunc:
# TODO: find torch alternative for slice_along dim for torch.jit.script to work
interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1)

@@ -746,6 +746,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
and (shape_env := loop_count.node.shape_env)
and loop_count in shape_env.pending_fresh_unbacked_symbols
):
# pyrefly: ignore # unbound-name
shape_env.pending_fresh_unbacked_symbols.remove(loop_count)

# Even when body function is not executed, we clone and unsqueeze the input
@@ -198,6 +198,7 @@ def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str:
to a file. The yaml string can be loaded back into an operator profile
structure using `read_profiles_from_yaml`.
"""
# pyrefly: ignore # import-error
import yaml

from torch._export.serde.serialize import (
@@ -262,6 +263,7 @@ def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]:
"""
Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
"""
# pyrefly: ignore # import-error
import yaml

from torch._export.serde.serialize import (

@@ -914,6 +914,7 @@ class TorchLogsFormatter(logging.Formatter):
and (trace_id := torch._guards.CompileContext.current_trace_id())
is not None
):
# pyrefly: ignore # unbound-name
record.traceid = f" [{trace_id}]"

glog_level_to_abbr = {

@@ -113,7 +113,6 @@ def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
if len(a) != len(b):
return False

# pyrefly: ignore # bad-assignment
for x, y in zip(a, b):
if allow_rhs_unbacked:
if isinstance(y, torch.SymInt):

@@ -6682,7 +6682,7 @@ def _infer_scalar_type(obj):
# double.
if length == 0:
return torch.get_default_dtype()
# pyrefly: ignore # bad-assignment

for i in range(length):
cur_item = obj[i]
# TODO: test this
@@ -106,13 +106,12 @@ def _resize_fft_input(
if x_sizes[dims[i]] < sizes[i]:
must_copy = True
pad_idx = len(pad_amount) - 2 * dims[i] - 1
# pyrefly: ignore # unsupported-operation

pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]

if x_sizes[dims[i]] > sizes[i]:
x = x.narrow(dims[i], 0, sizes[i])

# pyrefly: ignore # bad-argument-type
return torch.constant_pad_nd(x, pad_amount) if must_copy else x


@@ -1374,6 +1374,7 @@ class FakeTensorMode(TorchDispatchMode):
return self._stack

@count
# pyrefly: ignore # bad-override
def __torch_dispatch__(
self,
func: OpOverload,
@@ -2624,6 +2625,7 @@ class FakeTensorMode(TorchDispatchMode):
and s.rhs == 1
):
assert self.shape_env is not None
# pyrefly: ignore # unbound-name
self.shape_env.set_unbacked_var_to_val(s, int(real_t))

if real_out is not nil:

@@ -1820,6 +1820,7 @@ class MetaConverter(Generic[_TensorT]):

# TODO: Use a valid grad-specific symbolic context instead of recycling
# the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
# pyrefly: ignore # unbound-name
r.grad = self.meta_tensor(
t.grad,
shape_env,
@@ -1827,12 +1828,15 @@ class MetaConverter(Generic[_TensorT]):
AttrSource(source, "grad"),
symbolic_context,
)
# pyrefly: ignore # unbound-name
torch._C._set_conj(r, t.is_conj)
# pyrefly: ignore # unbound-name
torch._C._set_neg(r, t.is_neg)
# This can be skipped if necessary for performance reasons
skip_leaf = (
t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
)
# pyrefly: ignore # unbound-name
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
# Thanks to storage resizing, it's possible to end up with a tensor
# that advertises a real size, but has a storage that actually has zero bytes.
@@ -1840,14 +1844,18 @@ class MetaConverter(Generic[_TensorT]):
from torch.fx.experimental.symbolic_shapes import guard_or_false

if t.storage is not None and guard_or_false(t.storage.size == 0):
# pyrefly: ignore # unbound-name
r.untyped_storage().resize_(0)

if t.is_parameter:
# pyrefly: ignore # unbound-name
r._is_param = True

# See Note: [Creating symbolic nested int]
if t.nested_int is not None:
# pyrefly: ignore # unbound-name
assert _is_fake_tensor(r)
# pyrefly: ignore # unbound-name
r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
nt_tensor_id=t.nested_int
)

@@ -1120,6 +1120,7 @@ class Tensor(torch._C.TensorBase):
__rtruediv__ = __rdiv__
__itruediv__ = _C.TensorBase.__idiv__

# pyrefly: ignore # bad-override
__pow__ = cast(
Callable[
["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]],

@@ -657,8 +657,10 @@ def _str_intern(inp, *, tensor_contents=None):
grad_fn_name = "Invalid"

if grad_fn_name is None and grad_fn is not None:  # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
grad_fn_name = type(grad_fn).__name__
if grad_fn_name == "CppFunction":
# pyrefly: ignore # unbound-name
grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

if grad_fn_name is not None:

@@ -89,6 +89,7 @@ def compile_time_strobelight_meta(
skip := kwargs["skip"],
int,
):
# pyrefly: ignore # unbound-name
kwargs["skip"] = skip + 1

# This is not needed but we have it here to avoid having profile_compile_time
@@ -951,6 +951,7 @@ def create_a_shadows_b(
if should_log_inputs:
# skip the input logger when inserting a dtype cast
if isinstance(prev_node_c, Node):
# pyrefly: ignore # unbound-name
prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
elif isinstance(prev_node_c, list):
prev_node_c = [
@@ -959,6 +960,7 @@ def create_a_shadows_b(
]
dtype_cast_node = _insert_dtype_cast_after_node(
subgraph_a.start_node,
# pyrefly: ignore # unbound-name
node_c,
prev_node_c,
gm_a,
@@ -1039,7 +1041,10 @@ def create_a_shadows_b(
if num_non_param_args_node_a == 2:
# node_c_second_non_param_arg = node_c.args[1]
node_c_second_non_param_arg = get_normalized_nth_input(
node_c, gm_b, 1
# pyrefly: ignore # unbound-name
node_c,
gm_b,
1,
)
node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
dtype_cast_node,
@@ -1047,6 +1052,7 @@ def create_a_shadows_b(
subgraph_a,
gm_a,
gm_b,
# pyrefly: ignore # unbound-name
node_c.name + "_shadow_copy_",
)
env_c[node_a_shadows_c.name] = node_a_shadows_c
@@ -1069,11 +1075,15 @@ def create_a_shadows_b(
cur_node = node_a_shadows_c
while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:  # type: ignore[possibly-undefined]
cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
# pyrefly: ignore # unbound-name
if isinstance(input_logger, Node):
# pyrefly: ignore # unbound-name
input_logger_mod = getattr(gm_b, input_logger.name)
input_logger_mod.ref_node_name = cur_node.name
else:
# pyrefly: ignore # unbound-name
assert isinstance(input_logger, list)
# pyrefly: ignore # unbound-name
for input_logger_inner in input_logger:
input_logger_mod = getattr(gm_b, input_logger_inner.name)
input_logger_mod.ref_node_name = cur_node.name
@@ -93,6 +93,7 @@ class OutputProp:
)

if isinstance(result, torch.Tensor):  # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
node.traced_result = result

# pyrefly: ignore # unsupported-operation

@@ -404,7 +404,7 @@ def maybe_add_missing_fqns(results: NSResultsType) -> None:
for model_name, model_results in model_name_to_results.items():
if model_name == model_name_with_fqns:
continue
# pyrefly: ignore # bad-assignment

for i in range(len(model_results)):
fqn = ref_model_results[i]["fqn"]
model_results[i]["fqn"] = fqn

@@ -27,6 +27,7 @@ def run_forward(model, **batch):
model(X, lS_o, lS_i)
end = time.time()
time_taken = end - start
# pyrefly: ignore # bad-argument-type
time_list.append(time_taken)
avg_time = np.mean(time_list[1:])
return avg_time

@@ -127,6 +127,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor:
linear.out_features = linear.weight.shape[0]
_remove_bias_handles(linear)

# pyrefly: ignore # unbound-name
return mask


@@ -185,6 +186,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
conv2d.out_channels = conv2d.weight.shape[0]

_remove_bias_handles(conv2d)
# pyrefly: ignore # unbound-name
return mask


@@ -205,6 +207,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
new_bias = torch.zeros(conv2d_1.bias.shape)
new_bias[mask] = conv2d_1.bias[mask]  # type: ignore[possibly-undefined]
# adjusted bias that to keep in conv2d_1
# pyrefly: ignore # unbound-name
new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
# pruned biases that are kept instead of propagated
conv2d_1.bias = nn.Parameter(new_bias)
@@ -72,7 +72,6 @@ def _find_q_dq_node_for_user(
dq_node = n
break
if dq_node is None:
# pyrefly: ignore # bad-assignment
for n in user.kwargs:
if (
isinstance(n, torch.fx.Node)
@@ -91,6 +90,7 @@ def _find_q_dq_node_for_user(
and arg.op == "call_function"
and arg.target in _QUANTIZE_OPS
):
# pyrefly: ignore # unbound-name
q_node = arg
return (q_node, dq_node)

@@ -414,5 +414,6 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):
def __enter__(self) -> None:
pass

# pyrefly: ignore # bad-override
def __exit__(self, *args) -> None:
torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)
@@ -48,14 +48,11 @@ class EventList(list):
def _remove_dup_nodes(self):
while True:
to_delete = set()
# pyrefly: ignore # bad-assignment

for idx in range(len(self)):
if (
# pyrefly: ignore # index-error
self[idx].cpu_parent is not None
# pyrefly: ignore # index-error
and self[idx].cpu_parent.name == self[idx].name
# pyrefly: ignore # index-error
and len(self[idx].cpu_parent.cpu_children) == 1
):
self[idx].cpu_parent.cpu_children = self[idx].cpu_children
@@ -65,11 +62,11 @@ class EventList(list):
to_delete.add(idx)
if len(to_delete) == 0:
break
# pyrefly: ignore # bad-argument-type

new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
# pyrefly: ignore # missing-attribute

self.clear()
# pyrefly: ignore # missing-attribute

self.extend(new_evts)

def _populate_cpu_children(self):

@@ -90,6 +90,7 @@ class CacheArtifactFactory:
@classmethod
def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact:
artifact_cls = cls._get_artifact_type(artifact_type_key)
# pyrefly: ignore # bad-instantiation
return artifact_cls(key, content)

@classmethod
@@ -97,6 +98,7 @@ class CacheArtifactFactory:
cls, artifact_type_key: str, key: str, content: Any
) -> CacheArtifact:
artifact_cls = cls._get_artifact_type(artifact_type_key)
# pyrefly: ignore # bad-instantiation
return artifact_cls(key, artifact_cls.encode(content))

@@ -377,6 +377,7 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):

nn_module_stack = {
root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
# pyrefly: ignore # unbound-name
**nn_module_stack,
}
node.meta["nn_module_stack"] = {
@@ -525,6 +526,7 @@ def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None:
simplify=True,
)
):
# pyrefly: ignore # unbound-name
node.meta["unbacked_bindings"] = unbacked_bindings


@@ -662,7 +664,6 @@ def _rename_constants_nodes(
if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
const_prefix
):
# pyrefly: ignore # bad-argument-type
if spec.arg.name.startswith(buffer_prefix):  # map from buffer to constants
c_name = rename_constant(
const_prefix + spec.arg.name[len(buffer_prefix) :]

@@ -332,7 +332,9 @@ class _TorchNumpyPickleData:
if not (name := getattr(np, "__name__", None)):
return None

# pyrefly: ignore # unbound-name
assert np == getattr(importlib.import_module(mod), name)
# pyrefly: ignore # unbound-name
return cls(mod, name)

@@ -1793,17 +1793,14 @@ class _ModuleStackTracer(PythonKeyTracer):
self.enable_attr_proxy = False
self.submodule_paths = {}
for name, m in self.scope_root.named_modules(remove_duplicate=False):
# pyrefly: ignore # unsupported-operation
if m in self.submodule_paths:
log.info(
"Shared module found between %s and %s, AttrProxy is enabled.",
# pyrefly: ignore # unsupported-operation
self.submodule_paths[m],
name,
)
self.enable_attr_proxy = True
else:
# pyrefly: ignore # unsupported-operation
self.submodule_paths[m] = name

self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary()
@@ -2365,6 +2362,7 @@ class _MakefxTracer:
):
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts

# pyrefly: ignore # unbound-name
insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx")
t.recompile()
# TODO: kind of a bad way to do it, should maybe figure out a better way

@@ -620,11 +620,13 @@ def rebind_unbacked(
):
# This is what the pattern match above is testing
repacked = _sympy_cast_symbool_to_symint_guardless(
# pyrefly: ignore # unbound-name
sympy.Eq(new_raw_u1, 1)
)
assert repacked == raw_u1, f"{repacked} != {raw_u1}"
# Cancel the to_int(to_bool(x)). This is sound because x in
# [0, 1]
# pyrefly: ignore # unbound-name
raw_u1 = new_raw_u1

if not isinstance(raw_u1, sympy.Symbol):
@@ -1025,6 +1027,7 @@ def find_symbol_binding_fx_nodes(
# NB: Prefer first occurrence of symbol
for node in graph.nodes:
if (s := is_symbol_binding_fx_node(node)) is not None and s not in r:
# pyrefly: ignore # unbound-name
r[s] = node
return r

@@ -1195,10 +1198,13 @@ def _free_unbacked_symbols_with_path(
and isinstance(s := expr(a), sympy.Symbol)
and s in pending
):
# pyrefly: ignore # unbound-name
r[s] = path
if shape_env and real is not None:
assert isinstance(real, (int, float))
# pyrefly: ignore # unbound-name
shape_env.set_unbacked_var_to_val(s, real)
# pyrefly: ignore # unbound-name
pending.remove(s)
# When an unbacked SymInt is perfectly divisible by an integer
# constant, we replace it with the integer constant to improve
@@ -1228,20 +1234,27 @@ def _free_unbacked_symbols_with_path(
source=shape_env.var_to_sources.get(s, [None])[0],  # type: ignore[union-attr]
)

# pyrefly: ignore # unbound-name
unbacked = lhs if lhs in pending else rhs
divisor: IntLikeType = (
# pyrefly: ignore # unbound-name
int(coeff)
# pyrefly: ignore # unbound-name
if shape_env and isinstance(coeff, sympy.Integer)
# pyrefly: ignore # unbound-name
else _symint_wrap(coeff)
)
# TODO: DivideByKey needs to test divisibility at runtime!
# pyrefly: ignore # unsupported-operation

r[unbacked] = path + (DivideByKey(divisor),)
if real is not None:
assert isinstance(real, int)
val = (
# pyrefly: ignore # unbound-name
real // int(coeff)
# pyrefly: ignore # unbound-name
if isinstance(coeff, sympy.Integer)
# pyrefly: ignore # unbound-name
else CleanDiv(real, coeff)
)
if shape_env:
@@ -1263,7 +1276,9 @@ def _free_unbacked_symbols_with_path(
if real is not None:
assert type(real) is bool
if shape_env:
# pyrefly: ignore # unbound-name
shape_env.set_unbacked_var_to_val(s, int(real))
# pyrefly: ignore # unbound-name
pending.remove(s.lhs)

return r
@@ -1339,6 +1354,7 @@ def compute_unbacked_bindings(
):
if (
isinstance(old_sym, SymTypes)
# pyrefly: ignore # unbound-name
and (old_s := old_sym.node.expr) != new_s
):
# If old_s is not an unbacked_symbol,
@@ -1348,11 +1364,15 @@ def compute_unbacked_bindings(
# and the original symbol gets replaced by the backed symbol.
# When this happens we just replace new_s by the old_s
# because we know the value is the same.
# pyrefly: ignore # unbound-name
if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s):
# pyrefly: ignore # unbound-name
shape_env._rename_unbacked_to(new_s, old_s)
else:
# pyrefly: ignore # unbound-name
shape_env._eliminate_unbacked(new_s, old_s)
elif not isinstance(old_sym, SymTypes):
# pyrefly: ignore # unbound-name
shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym))

return symbol_to_path
@@ -3317,6 +3337,7 @@ class DimConstraints:
and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
):  # derived dim with root = old_root
new_root_expr = results[str(old_root)]["eq"]  # dx=3*_dx+1
# pyrefly: ignore # unbound-name
new_expr = c["eq"].subs({symbol: new_root_expr})  # dy=(3*_dx+1)+1
c["eq"] = new_expr

@@ -5313,7 +5334,7 @@ class ShapeEnv:
]
else:
assert len(input_contexts) == len(placeholders)
# pyrefly: ignore # bad-assignment

for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
if isinstance(t, Tensorlike):
if context is None:
@@ -5663,13 +5684,12 @@ class ShapeEnv:
)
track_symint(property_source, ss, constraint_size[i])
else:
# pyrefly: ignore # missing-attribute
for i, ss in enumerate(curr_t.size()):
property_source = TensorPropertySource(
src, TensorProperty.SIZE, i
)
track_symint(property_source, ss, constraint_size[i])
# pyrefly: ignore # missing-attribute

for i, ss in enumerate(curr_t.stride()):
property_source = TensorPropertySource(
src, TensorProperty.STRIDE, i
@@ -5677,7 +5697,6 @@ class ShapeEnv:
track_symint(property_source, ss, constraint_stride[i])
track_symint(
TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
# pyrefly: ignore # missing-attribute
curr_t.storage_offset(),
)

@@ -5723,7 +5742,6 @@ class ShapeEnv:
continue

if is_dim(source):
# pyrefly: ignore # missing-attribute
self.dim_constraints.add_equality(source, expr)

for exprs, printer, lang in zip(all_exprs, printers, langs):
@@ -5877,7 +5895,6 @@ class ShapeEnv:
continue
expr = self.simplify(ra.expr)

# pyrefly: ignore # missing-attribute
self.dim_constraints.add(expr)

# 3. Every symbol must be within its value range (this handles 0/1
@@ -5894,7 +5911,6 @@ class ShapeEnv:
verbose_expr = ""
if r.lower not in (-sympy.oo, -int_oo):
if any(is_dim(source) for source in sources):
# pyrefly: ignore # missing-attribute
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
# Only print lower bound in simplified mode if it is not the
# default
@@ -5903,7 +5919,6 @@ class ShapeEnv:
verbose_expr = f"{r.lower} <= {rf}  # {vr_sloc.lower}"
if r.upper not in (sympy.oo, int_oo):
if any(is_dim(source) for source in sources):
# pyrefly: ignore # missing-attribute
self.dim_constraints.add(sympy.Le(symbol, r.upper))
# nontrivial upper bound is always interesting
bounds.append(sympy.Le(symbol, r.upper, evaluate=False))
@@ -6152,7 +6167,6 @@ class ShapeEnv:
else:
bindings[-s] = -arg

# pyrefly: ignore # bad-assignment
for t, arg in zip(placeholders, args):
if t is None:
continue
@@ -7588,8 +7602,10 @@ class ShapeEnv:
log.info(
"oblivious_size %s -> %s (passed counterfactual)",
orig_expr,
# pyrefly: ignore # unbound-name
correct_hint,
)
# pyrefly: ignore # unbound-name
concrete_val = correct_hint
# NB: do NOT transmute into runtime assert
ok = True
@@ -7606,8 +7622,10 @@ class ShapeEnv:
).xreplace(self.var_to_val)
).free_symbols
):
# pyrefly: ignore # unbound-name
self._log_real_tensor_propagation(orig_expr, unsound_result)
transmute_into_runtime_assert = True
# pyrefly: ignore # unbound-name
concrete_val = unsound_result
ok = True

@@ -8035,7 +8053,6 @@ def _suggest_fixes_for_data_dependent_error_non_strict(
if isinstance(leaf, torch.SymInt):
src_map[str(leaf.node.expr)].append(name)
elif isinstance(leaf, torch.Tensor):
# pyrefly: ignore # bad-assignment
for i, dim in enumerate(leaf.shape):
if isinstance(dim, torch.SymInt):
src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]")

@@ -407,6 +407,7 @@ class MethodDispatcher(Dispatcher):
Dispatcher
"""

# pyrefly: ignore # bad-override
__slots__ = ("obj", "cls")

@classmethod
@@ -120,7 +120,7 @@ def _torchscript_schema_to_signature_impl(
# which makes it hard to do type annotation
kind = Parameter.POSITIONAL_ONLY  # type: ignore[assignment]
# This renders all previous arguments to positional only
# pyrefly: ignore # bad-assignment

for idx, p in enumerate(parameters):
assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
parameters[idx] = Parameter(
@@ -129,7 +129,7 @@ def _torchscript_schema_to_signature_impl(
default=p.default,
annotation=p.annotation,
)
# pyrefly: ignore # missing-attribute

parameters.append(
Parameter(name=name, kind=kind, default=default, annotation=arg_type)
)
@@ -143,7 +143,6 @@ def _torchscript_schema_to_signature_impl(
else:
return_type = tuple(return_types)

# pyrefly: ignore # bad-argument-type
return inspect.Signature(parameters, return_annotation=return_type)


@@ -241,7 +241,6 @@ def tensorify_python_scalars(
# pyrefly: ignore # missing-attribute
val = node.meta.get("val")
if isinstance(val, FakeTensor):
# pyrefly: ignore # bad-assignment
for dim in val.shape:
if isinstance(dim, torch.SymInt):
for s in dim.node.expr.free_symbols:
@@ -277,6 +276,7 @@ def tensorify_python_scalars(
):
transform = True
try:
# pyrefly: ignore # unbound-name
proxy = _sympy_interp(zf.node.expr)
except NotImplementedError:
transform = False
@@ -303,6 +303,7 @@ def tensorify_python_scalars(
args.append(a)

if transform:
# pyrefly: ignore # unbound-name
replacement_proxy = replacement_op(*args)

# pyrefly: ignore # missing-attribute
@@ -93,6 +93,7 @@ class FakeTensorProp(torch.fx.Interpreter):
if (shape_env := self._mode.shape_env) and (
symbol_to_path := compute_unbacked_bindings(shape_env, result)
):
# pyrefly: ignore # unbound-name
n.meta["unbacked_bindings"] = symbol_to_path

return result

@@ -274,7 +274,6 @@ class PassManager:
logger.debug("Running pass '%s'", fn_name)

try:
# pyrefly: ignore # not-callable
res = fn(module)

if not isinstance(res, PassResult) and not hasattr(

@@ -395,21 +395,25 @@ class _MinimizerBase:
report.append(f"Result mismatch for {result_key}")  # type: ignore[possibly-undefined]
if self.module_exporter:
if isinstance(result_key, tuple):  # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
result_key = result_key[-1]
# If the result is still a tuple (happens in non-sequential mode),
# we only use the first element as name.
if isinstance(result_key, tuple):  # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
result_key = str(result_key[0])
# pyre-ignore[29]: not a function
self.module_exporter(
a_input,
submodule,
# pyrefly: ignore # unbound-name
result_key + "_cpu",
)
# pyre-ignore[29]: not a function
self.module_exporter(
b_input,
submodule,
# pyrefly: ignore # unbound-name
result_key + "_acc",
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")  # type: ignore[possibly-undefined]
@@ -298,10 +298,14 @@ def insert_deferred_runtime_asserts(
and s not in expr_to_proxy
):
with _set_node_metadata_hook(gm, _node_metadata_hook):
# pyrefly: ignore # unbound-name
expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
# pyrefly: ignore # unbound-name
log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])

# pyrefly: ignore # unbound-name
match_symbol(example_value, lambda: node)
# pyrefly: ignore # unbound-name
if isinstance(t := example_value, torch.Tensor):
for i, s in enumerate(t.size()):
match_symbol(
@@ -382,6 +386,7 @@ def insert_deferred_runtime_asserts(

# maybe re-reify expression, replace current node
if (
# pyrefly: ignore # unbound-name
sym_expr in expr_to_proxy
or (  # example value is redundant
_is_intermediate_tensor_sym_call(node)
@@ -400,20 +405,30 @@ def insert_deferred_runtime_asserts(
nn_module_stack=node.meta.get("nn_module_stack"),
),
):
# pyrefly: ignore # unbound-name
expr_to_proxy[sym_expr] = _sympy_interp(
expr_to_proxy, sym_expr
expr_to_proxy,
# pyrefly: ignore # unbound-name
sym_expr,
)  # type: ignore[arg-type]
# won't try DCE-ing tensor compute here
hash_node = expr_to_proxy[sym_expr].node  # type: ignore[arg-type]
node.replace_all_uses_with(hash_node)
gm.graph.erase_node(node)
log.debug(
"CSE node %s -> %s for expr %s", node, hash_node, sym_expr
"CSE node %s -> %s for expr %s",
node,
hash_node,
# pyrefly: ignore # unbound-name
sym_expr,
)

# store node in hash cons, don't delete/replace
# pyrefly: ignore # unbound-name
elif sym_expr not in expr_to_proxy and not isinstance(
sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
# pyrefly: ignore # unbound-name
sym_expr,
(sympy.Number, sympy.logic.boolalg.BooleanAtom),
):  # don't hash cons primitives
expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer)  # type: ignore[arg-type]
@@ -317,6 +317,7 @@ def split_module(
and isinstance(s0 := val.node.expr, sympy.Symbol)
and s0 not in symbol_to_node
):
# pyrefly: ignore # unbound-name
symbol_to_node[val.node.expr] = node

if node.op in ["placeholder", "get_attr", "output"]:

@@ -84,6 +84,7 @@ def get_source_partitions(
if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and (
torch_fn := node.meta.get("torch_fn", None)
) is not None:
# pyrefly: ignore # unbound-name
node_fqn, source_fn = torch_fn
source_fn_name = source_fn.split(".")[1]
if source_fn_name in wanted_sources:

@@ -288,7 +288,7 @@ def _replace_pattern(
elif isinstance(pattern, Graph):
pattern_graph = pattern
else:
pattern_graph = symbolic_trace(pattern).graph
pattern_graph = symbolic_trace(pattern).graph  # type: ignore[arg-type]

matcher = SubgraphMatcher(
pattern_graph,
@@ -321,7 +321,7 @@ def _replace_pattern(
assert replacement_callback is not None, (
"Must provide either a replacement GraphModule or a replacement callback"
)
common_replacement_graph = None
common_replacement_graph = None  # type: ignore[assignment]

# As we progressively replace nodes, we'll need to keep track of how the match results should change
match_changed_node: dict[Node, Node] = {}

@@ -561,7 +561,6 @@ def cat(tensors: list[list[int]], dim: int):
for i in range(len(tensors)):
tensor = tensors[i]
if not should_skip(tensor):
# pyrefly: ignore # bad-argument-type
check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
cat_dim_size = cat_dim_size + tensor[dim]


@@ -128,6 +128,7 @@ def _format_model_info(model_info: ModelInfo) -> str:
target_to_messages = {}
for node, message in model_info.dispatch_failures:
if str(node.target) not in target_to_messages:
# pyrefly: ignore # unsupported-operation
target_to_messages[str(node.target)] = message

for target, nodes in sorted(
@@ -149,6 +149,7 @@ class ElementwiseTypePromotionRule(TypePromotionRule):
f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})"
)

# pyrefly: ignore # bad-override
def __eq__(self, other: object, /) -> bool:
if not isinstance(other, ElementwiseTypePromotionRule):
return False
@@ -265,6 +266,7 @@ class ReductionTypePromotionRule(TypePromotionRule):
def __repr__(self):
return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})"

# pyrefly: ignore # bad-override
def __eq__(self, other: object, /) -> bool:
if not isinstance(other, ElementwiseTypePromotionRule):
return False

@@ -298,9 +298,12 @@ def _create_node(
for key, value in sorted(attributes.items()):
if key in _SKIP_NODE_ATTRIBUTES:
continue
# pyrefly: ignore # unbound-name
_add_attribute(node, key, value, aten=aten)
if shape_inference:
# pyrefly: ignore # unbound-name
_C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
# pyrefly: ignore # unbound-name
return node


@@ -219,7 +219,6 @@ def index_put(
if len(indices_list) > 1:
for idx_ in range(len(indices_list)):
if symbolic_helper._is_bool(indices_list[idx_]):
# pyrefly: ignore # unsupported-operation
indices_list[idx_] = g.op("NonZero", indices_list[idx_])
index = indices_list[0]


@@ -698,7 +698,6 @@ def _multi_tensor_adam(
device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
)

# pyrefly: ignore # no-matching-overload
torch._foreach_mul_(device_exp_avg_sqs, beta2)

# Due to the strictness of the _foreach_addcmul API, we can't have a single
@@ -115,11 +115,14 @@ def _strong_wolfe(
t = _cubic_interpolate(
# pyrefly: ignore # index-error
bracket[0],
# pyrefly: ignore # unbound-name
bracket_f[0],
bracket_gtd[0],  # type: ignore[possibly-undefined]
# pyrefly: ignore # index-error
bracket[1],
# pyrefly: ignore # unbound-name
bracket_f[1],
# pyrefly: ignore # unbound-name
bracket_gtd[1],
)

@@ -130,14 +133,20 @@ def _strong_wolfe(
# + `t` is at one of the boundary,
# we will move `t` to a position which is `0.1 * len(bracket)`
# away from the nearest boundary point.
# pyrefly: ignore # unbound-name
eps = 0.1 * (max(bracket) - min(bracket))
# pyrefly: ignore # unbound-name
if min(max(bracket) - t, t - min(bracket)) < eps:
# interpolation close to boundary
# pyrefly: ignore # unbound-name
if insuf_progress or t >= max(bracket) or t <= min(bracket):
# evaluate at 0.1 away from boundary
# pyrefly: ignore # unbound-name
if abs(t - max(bracket)) < abs(t - min(bracket)):
# pyrefly: ignore # unbound-name
t = max(bracket) - eps
else:
# pyrefly: ignore # unbound-name
t = min(bracket) + eps
insuf_progress = False
else:
@@ -151,13 +160,17 @@ def _strong_wolfe(
gtd_new = g_new.dot(d)
ls_iter += 1

# pyrefly: ignore # unbound-name
if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
# Armijo condition not satisfied or not lower than lowest point
# pyrefly: ignore # unsupported-operation
bracket[high_pos] = t
# pyrefly: ignore # unbound-name
bracket_f[high_pos] = f_new
bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
bracket_gtd[high_pos] = gtd_new
# pyrefly: ignore # unbound-name
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
else:
if abs(gtd_new) <= -c2 * gtd:
@@ -168,19 +181,24 @@ def _strong_wolfe(
# old high becomes new low
# pyrefly: ignore # unsupported-operation
bracket[high_pos] = bracket[low_pos]
# pyrefly: ignore # unbound-name
bracket_f[high_pos] = bracket_f[low_pos]
bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
bracket_gtd[high_pos] = bracket_gtd[low_pos]

# new point becomes new low
# pyrefly: ignore # unsupported-operation
bracket[low_pos] = t
# pyrefly: ignore # unbound-name
bracket_f[low_pos] = f_new
bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
bracket_gtd[low_pos] = gtd_new

# return stuff
t = bracket[low_pos]  # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
f_new = bracket_f[low_pos]
g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]
return f_new, g_new, t, ls_func_evals
@@ -420,6 +420,7 @@ class LambdaLR(LRScheduler):

for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
# pyrefly: ignore # unsupported-operation
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()

return state_dict
@@ -539,6 +540,7 @@ class MultiplicativeLR(LRScheduler):

for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
# pyrefly: ignore # unsupported-operation
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()

return state_dict
@@ -1215,6 +1217,7 @@ class SequentialLR(LRScheduler):
state_dict["_schedulers"] = [None] * len(self._schedulers)

for idx, s in enumerate(self._schedulers):
# pyrefly: ignore # unsupported-operation
state_dict["_schedulers"][idx] = s.state_dict()

return state_dict
@@ -1557,6 +1560,7 @@ class ChainedScheduler(LRScheduler):
state_dict["_schedulers"] = [None] * len(self._schedulers)

for idx, s in enumerate(self._schedulers):
# pyrefly: ignore # unsupported-operation
state_dict["_schedulers"][idx] = s.state_dict()

return state_dict

@@ -337,7 +337,6 @@ def _single_tensor_sgd(
if not torch.jit.is_scripting():
lr = _to_scalar(lr)

# pyrefly: ignore # bad-assignment
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]

@@ -433,12 +432,10 @@ def _multi_tensor_sgd(

all_states_with_momentum_buffer = True
for i in range(len(device_momentum_buffer_list)):
# pyrefly: ignore # index-error
if device_momentum_buffer_list[i] is None:
all_states_with_momentum_buffer = False
break
else:
# pyrefly: ignore # index-error
bufs.append(cast(Tensor, device_momentum_buffer_list[i]))

if all_states_with_momentum_buffer:
@@ -446,15 +443,13 @@ def _multi_tensor_sgd(
torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
else:
bufs = []
# pyrefly: ignore # bad-assignment

for i in range(len(device_momentum_buffer_list)):
# pyrefly: ignore # index-error
if device_momentum_buffer_list[i] is None:
buf = device_momentum_buffer_list[i] = momentum_buffer_list[
indices[i]
] = device_grads[i].detach().clone()
else:
# pyrefly: ignore # index-error
buf = cast(Tensor, device_momentum_buffer_list[i])
buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)


@@ -672,7 +672,7 @@ class MemoryProfile:
output: list[tuple[int, Action, KeyAndID, int]] = []
allocation_times: dict[tuple[TensorKey, bool], int] = {}
live_unknown: dict[tuple[int, torch.device], Literal[True]] = {}
# pyrefly: ignore # bad-assignment

for event in self._op_tree.dfs():
if event.typed[0] == _EventType.Allocation:
alloc_fields = event.typed[1]
@@ -774,14 +774,12 @@ class MemoryProfile:
for key, (_, version) in node.inputs.items()
if self._categories.get(key, version)
in (Category.GRADIENT, Category.PARAMETER)
# pyrefly: ignore # unsupported-operation
or key.id in depends_on_gradient
)

if ids:
# pyrefly: ignore # missing-attribute
depends_on_gradient.update(ids)
# pyrefly: ignore # missing-attribute

depends_on_gradient.update(key.id for key in node.outputs)

# We are guaranteed to exit because there is a finite set of
@@ -790,7 +788,6 @@ class MemoryProfile:
# once to fold the first step into that loop, and a third time
# where no new elements are added.
if len(depends_on_gradient) == start_size:
# pyrefly: ignore # bad-return
return depends_on_gradient

def _set_gradients_and_temporaries(self) -> None:
@@ -140,6 +140,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):

if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
@@ -172,6 +173,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device
)
# pyrefly: ignore # unbound-name
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

return (sparse, meta_reordered.view(m, meta_ncols))

@@ -385,7 +385,7 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
g1 = c_offsets[r + 1]
for g in range(g0, g1):
p, q = pq[g]
# pyrefly: ignore # unsupported-operation

accumulators[r] += blocks[p] @ others[q]
else:
_scatter_mm2(blocks, others, c_offsets, pq, accumulators)

@@ -1219,6 +1219,7 @@ def originate_pairs(
else:
for pair_type in pair_types:
try:
# pyrefly: ignore # bad-instantiation
return [pair_type(actual, expected, id=id, **options)]
# Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
# inputs. Thus, we try the next pair type.

@@ -95,7 +95,7 @@ from torch.utils._import_utils import _check_module_exists
import torch.utils._pytree as pytree
from torch.utils import cpp_extension
try:
import pytest
import pytest  # type: ignore[import-not-found]
has_pytest = True
except ImportError:
has_pytest = False

@@ -117,6 +117,7 @@ def context_decorator(ctx, func):

@functools.wraps(func)
def decorate_context(*args, **kwargs):
# pyrefly: ignore # bad-context-manager
with ctx_factory():
return func(*args, **kwargs)

@ -41,7 +41,10 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
)

# pyrefly: ignore # import-error
import optree

# pyrefly: ignore # import-error
from optree import PyTreeSpec as TreeSpec # direct import for type annotations

@ -706,6 +709,7 @@ def tree_map_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@ -766,6 +770,7 @@ def tree_map_only_(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@ -1079,6 +1084,7 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
with python_pytree._NODE_REGISTRY_LOCK:
# pyrefly: ignore # bad-assignment
python_pytree._cxx_pytree_imported = True
args, kwargs = (), {} # type: ignore[var-annotated]
for args, kwargs in python_pytree._cxx_pytree_pending_imports:

@ -152,6 +152,7 @@ class DebugMode(TorchDispatchMode):
super().__enter__()
return self
# pyrefly: ignore # bad-override
def __exit__(self, *args):
super().__exit__(*args)
if self.record_torchfunction:

@ -60,6 +60,7 @@ def _device_constructors():
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
def __init__(self, device):
# pyrefly: ignore # read-only
self.device = torch.device(device)
def __enter__(self):

@ -35,10 +35,12 @@ def cache_method(
if not (cache := getattr(self, cache_name, None)):
cache = {}
setattr(self, cache_name, cache)
# pyrefly: ignore # unbound-name
cached_value = cache.get(args, _cache_sentinel)
if cached_value is not _cache_sentinel:
return cached_value
value = f(self, *args, **kwargs)
# pyrefly: ignore # unbound-name
cache[args] = value
return value

@ -158,6 +158,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
# MutableSet impl will iterate over other, iter over smaller of two sets
if isinstance(other, OrderedSet) and len(self) < len(other):
# pyrefly: ignore # unsupported-operation, bad-return
return other & self
return cast(OrderedSet[T], super().__and__(other))
@ -708,6 +708,7 @@ class structseq(tuple[_T_co, ...]):
def __new__(
cls: type[Self],
sequence: Iterable[_T_co],
# pyrefly: ignore # bad-function-definition
dict: dict[str, Any] = ...,
) -> Self:
raise NotImplementedError

@ -754,6 +755,7 @@ def _tuple_flatten_with_keys(
d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _tuple_flatten(d)
# pyrefly: ignore # bad-return
return [(SequenceKey(i), v) for i, v in enumerate(values)], context

@ -767,6 +769,7 @@ def _list_flatten(d: list[T]) -> tuple[list[T], Context]:
def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _list_flatten(d)
# pyrefly: ignore # bad-return
return [(SequenceKey(i), v) for i, v in enumerate(values)], context

@ -782,6 +785,7 @@ def _dict_flatten_with_keys(
d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _dict_flatten(d)
# pyrefly: ignore # bad-return
return [(MappingKey(k), v) for k, v in zip(context, values)], context

@ -797,6 +801,7 @@ def _namedtuple_flatten_with_keys(
d: NamedTuple,
) -> tuple[list[tuple[KeyEntry, Any]], Context]:
values, context = _namedtuple_flatten(d)
# pyrefly: ignore # bad-return
return (
[(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
context,

@ -846,6 +851,7 @@ def _ordereddict_flatten_with_keys(
d: OrderedDict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _ordereddict_flatten(d)
# pyrefly: ignore # bad-return
return [(MappingKey(k), v) for k, v in zip(context, values)], context

@ -870,6 +876,7 @@ def _defaultdict_flatten_with_keys(
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _defaultdict_flatten(d)
_, dict_context = context
# pyrefly: ignore # bad-return
return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context

@ -918,6 +925,7 @@ def _deque_flatten_with_keys(
d: deque[T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _deque_flatten(d)
# pyrefly: ignore # bad-return
return [(SequenceKey(i), v) for i, v in enumerate(values)], context

@ -1547,6 +1555,7 @@ def tree_map_only(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@ -1607,6 +1616,7 @@ def tree_map_only_(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@ -1819,6 +1829,7 @@ def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
for attr in classname.split("."):
enum_cls = getattr(enum_cls, attr)
enum_cls = cast(type[Enum], enum_cls)
# pyrefly: ignore # unsupported-operation
return enum_cls[obj["name"]]
return obj

@ -305,6 +305,7 @@ def strobelight(
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
# pyrefly: ignore # bad-argument-type
return profiler.profile(work_function, *args, **kwargs)
return wrapper_function
@ -105,6 +105,7 @@ def _keep_float(
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
@functools.wraps(f)
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
# pyrefly: ignore # bad-argument-type
r: Union[_T, sympy.Float] = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float

@ -112,6 +113,7 @@ def _keep_float(
r = sympy.Float(float(r))
return r
# pyrefly: ignore # bad-return
return inner

@ -198,10 +200,12 @@ class FloorDiv(sympy.Function):
@property
def base(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
return self.args[0]
@property
def divisor(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
return self.args[1]
def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:

@ -370,6 +374,7 @@ class ModularIndexing(sympy.Function):
return None
def _eval_is_nonnegative(self) -> Optional[bool]:
# pyrefly: ignore # missing-attribute
p, q = self.args[:2]
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]

@ -450,6 +455,7 @@ class PythonMod(sympy.Function):
# - floor(p / q) = 0
# - p % q = p - floor(p / q) * q = p
less = p < q
# pyrefly: ignore # missing-attribute
if less.is_Boolean and bool(less) and r.is_positive:
return p

@ -466,8 +472,11 @@ class PythonMod(sympy.Function):
return True if self.args[1].is_negative else None # type: ignore[attr-defined]
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
abs_q = str(q) if self.args[1].is_positive else f"abs({q})"
return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}"

@ -548,6 +557,7 @@ class CeilToInt(sympy.Function):
return sympy.Integer(math.ceil(float(number)))
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5)
return f"ceil({number})"

@ -818,6 +828,7 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
if not cond:
return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
if isinstance(ai, cls):
# pyrefly: ignore # missing-attribute
return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
return a

@ -995,6 +1006,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
return fuzzy_and(a.is_negative for a in self.args)

@ -1013,6 +1025,7 @@ class Min(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
return fuzzy_or(a.is_negative for a in self.args)

@ -1150,7 +1163,9 @@ class IntTrueDiv(sympy.Function):
return sympy.Float(int(base) / int(divisor))
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
return f"((int){base}/(int){divisor})"

@ -1310,9 +1325,11 @@ class Identity(sympy.Function):
precedence = 10
def __repr__(self): # type: ignore[override]
# pyrefly: ignore # missing-attribute
return f"Identity({self.args[0]})"
def _eval_is_real(self):
# pyrefly: ignore # missing-attribute
return self.args[0].is_real
def _eval_is_integer(self):

@ -1320,12 +1337,15 @@ class Identity(sympy.Function):
def _eval_expand_identity(self, **hints):
# Removes the identity op.
# pyrefly: ignore # missing-attribute
return self.args[0]
def __int__(self) -> int:
# pyrefly: ignore # missing-attribute
return int(self.args[0])
def __float__(self) -> float:
# pyrefly: ignore # missing-attribute
return float(self.args[0])
@ -9,6 +9,7 @@ from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton

# pyrefly: ignore # invalid-inheritance
class IntInfinity(Number, metaclass=Singleton):
r"""Positive integer infinite quantity.

@ -203,6 +204,7 @@ class IntInfinity(Number, metaclass=Singleton):
int_oo = S.IntInfinity

# pyrefly: ignore # invalid-inheritance
class NegativeIntInfinity(Number, metaclass=Singleton):
"""Negative integer infinite quantity.

@ -66,6 +66,7 @@ class ExprPrinter(StrPrinter):
# NB: this pow by natural, you should never have used builtin sympy.pow
# for FloatPow, and a symbolic exponent should be PowByNatural. These
# means exp is guaranteed to be integer.
# pyrefly: ignore # bad-override
def _print_Pow(self, expr: sympy.Expr) -> str:
base, exp = expr.args
assert exp == int(exp), exp

@ -175,6 +175,7 @@ class ReferenceAnalysis:
@staticmethod
def pow(a, b):
# pyrefly: ignore # bad-argument-type
return _keep_float(FloatPow)(a, b)
@staticmethod

@ -123,7 +123,9 @@ AllFn2 = Union[ExprFn2, BoolFn2]
class ValueRanges(Generic[_T]):
if TYPE_CHECKING:
# ruff doesn't understand circular references but mypy does
# pyrefly: ignore # unbound-name
ExprVR = ValueRanges[sympy.Expr] # noqa: F821
# pyrefly: ignore # unbound-name
BoolVR = ValueRanges[SympyBoolean] # noqa: F821
AllVR = Union[ExprVR, BoolVR]

@ -464,6 +466,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def to_dtype(a, dtype, src_dtype=None):
if dtype == torch.float64:
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(a, ToFloat)
elif dtype == torch.bool:
return ValueRanges.unknown_bool()

@ -473,6 +476,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def trunc_to_int(a, dtype):
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(a, TruncToInt)
@staticmethod

@ -621,7 +625,10 @@ class SymPyValueRangeAnalysis:
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(
a, b, _keep_float(IntTrueDiv)
a,
b,
# pyrefly: ignore # bad-argument-type
_keep_float(IntTrueDiv),
)
@staticmethod

@ -634,7 +641,10 @@ class SymPyValueRangeAnalysis:
return ValueRanges.unknown()
else:
return ValueRanges.coordinatewise_monotone_map(
a, b, _keep_float(FloatTrueDiv)
a,
b,
# pyrefly: ignore # bad-argument-type
_keep_float(FloatTrueDiv),
)
@staticmethod

@ -713,6 +723,7 @@ class SymPyValueRangeAnalysis:
# We should know that b >= 0 but we may have forgotten this fact due
# to replacements, so don't assert it, but DO clamp it to prevent
# degenerate problems
# pyrefly: ignore # no-matching-overload
return ValueRanges.coordinatewise_increasing_map(
a, b & ValueRanges(0, int_oo), PowByNatural
)

@ -879,6 +890,7 @@ class SymPyValueRangeAnalysis:
@classmethod
def round_to_int(cls, number, dtype):
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(number, RoundToInt)
# It's used in some models on symints

@ -992,6 +1004,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def trunc(x):
# pyrefly: ignore # bad-argument-type
return ValueRanges.increasing_map(x, TruncToFloat)
@ -202,6 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
Args:
device (int, optional): if specified, all parameters will be copied to that device
"""
# pyrefly: ignore # missing-attribute
return self._apply(lambda t: getattr(t, custom_backend_name)(device))
_check_register_once(torch.nn.Module, custom_backend_name)

@ -63,6 +63,7 @@ def generate_coo_data(size, sparse_dim, nnz, dtype, device):
indices = torch.rand(sparse_dim, nnz, device=device)
indices.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(indices))
indices = indices.to(torch.long)
# pyrefly: ignore # no-matching-overload
values = torch.rand([nnz, ], dtype=dtype, device=device)
return indices, values

@ -15,6 +15,7 @@ _warned_tensor_cores = False
_default_float_32_precision = torch.get_float32_matmul_precision()
try:
from tabulate import tabulate
HAS_TABULATE = True

@ -169,6 +170,7 @@ if HAS_TABULATE:
_disable_tensor_cores()
table.append([
("Training" if optimizer else "Inference"),
# pyrefly: ignore # redundant-condition
backend if backend else "-",
mode if mode is not None else "-",
f"{compilation_time} ms " if compilation_time else "-",

@ -189,4 +191,5 @@ if HAS_TABULATE:
])

# pyrefly: ignore # not-callable
return tabulate(table, headers=field_names, tablefmt="github")

@ -35,6 +35,7 @@ def _get_build_root() -> str:
global _BUILD_ROOT
if _BUILD_ROOT is None:
_BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
# pyrefly: ignore # missing-argument
atexit.register(shutil.rmtree, _BUILD_ROOT)
return _BUILD_ROOT

@ -91,6 +91,7 @@ class FuzzedSparseTensor(FuzzedTensor):
return x
def _make_tensor(self, params, state):
# pyrefly: ignore # missing-attribute
size, _, _ = self._get_size_and_steps(params)
density = params['density']
nnz = math.ceil(sum(size) * density)

@ -99,8 +100,10 @@ class FuzzedSparseTensor(FuzzedTensor):
is_coalesced = params['coalesced']
sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size)
sparse_dim = min(sparse_dim, len(size))
# pyrefly: ignore # missing-attribute
tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced)
# pyrefly: ignore # missing-attribute
if self._cuda:
tensor = tensor.cuda()
sparse_dim = tensor.sparse_dim()

@ -116,6 +119,7 @@ class FuzzedSparseTensor(FuzzedTensor):
"sparse_dim": sparse_dim,
"dense_dim": dense_dim,
"is_hybrid": is_hybrid,
# pyrefly: ignore # missing-attribute
"dtype": str(self._dtype),
}
return tensor, properties

@ -233,6 +233,7 @@ class Timer:
setup = textwrap.dedent(setup)
setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()
# pyrefly: ignore # bad-instantiation
self._timer = self._timer_cls(
stmt=stmt,
setup=setup,
@ -448,11 +448,13 @@ class GlobalsBridge:
load_lines = []
for name, wrapped_value in self._globals.items():
if wrapped_value.setup is not None:
# pyrefly: ignore # bad-argument-type
load_lines.append(textwrap.dedent(wrapped_value.setup))
if wrapped_value.serialization == Serialization.PICKLE:
path = os.path.join(self._data_dir, f"{name}.pkl")
load_lines.append(
# pyrefly: ignore # bad-argument-type
f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)")
with open(path, "wb") as f:
pickle.dump(wrapped_value.value, f)

@ -462,11 +464,13 @@ class GlobalsBridge:
# TODO: Figure out if we can use torch.serialization.add_safe_globals here
# Using weights_only=False after the change in
# https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
# pyrefly: ignore # bad-argument-type
load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)")
torch.save(wrapped_value.value, path)
elif wrapped_value.serialization == Serialization.TORCH_JIT:
path = os.path.join(self._data_dir, f"{name}.pt")
# pyrefly: ignore # bad-argument-type
load_lines.append(f"{name} = torch.jit.load({repr(path)})")
with open(path, "wb") as f:
torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call]

@ -222,6 +222,7 @@ def _get_autocast_kwargs(device_type="cuda"):
class CheckpointFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function

@ -784,6 +785,7 @@ class _Holder:
class _NoopSaveInputs(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(*args):
return torch.empty((0,))

@ -1006,6 +1008,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
def logging_mode():
with LoggingTensorMode(), \
capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
# pyrefly: ignore # bad-assignment
self.logs, self.tbs = logs_and_tb
yield logs_and_tb
return logging_mode()

@ -787,6 +787,7 @@ class BuildExtension(build_ext):
# Use absolute path for output_dir so that the object file paths
# (`objects`) get generated with absolute paths.
# pyrefly: ignore # no-matching-overload
output_dir = os.path.abspath(output_dir)
# See Note [Absolute include_dirs]

@ -977,6 +978,7 @@ class BuildExtension(build_ext):
is_standalone=False):
if not self.compiler.initialized:
self.compiler.initialize()
# pyrefly: ignore # no-matching-overload
output_dir = os.path.abspath(output_dir)
# Note [Absolute include_dirs]

@ -1528,6 +1530,7 @@ def include_paths(device_type: str = "cpu", torch_include_dirs=True) -> list[str
# Support CUDA_INC_PATH env variable supported by CMake files
if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
cuda_inc_path != '/usr/include':
# pyrefly: ignore # unbound-name
paths.append(cuda_inc_path)
if CUDNN_HOME is not None:
paths.append(os.path.join(CUDNN_HOME, 'include'))

@ -2569,6 +2572,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]:
def _get_vc_env(vc_arch: str) -> dict[str, str]:
try:
from setuptools import distutils # type: ignore[attr-defined]
# pyrefly: ignore # missing-attribute
return distutils._msvccompiler._get_vc_env(vc_arch)
except AttributeError:
try:
@ -204,6 +204,7 @@ def collate(
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
# pyrefly: ignore # not-iterable
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError("each element in list of batch should be of equal size")
transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.

@ -70,6 +70,7 @@ def pin_memory(data, device=None):
return clone
else:
return type(data)(
# pyrefly: ignore # bad-argument-count
{k: pin_memory(sample, device) for k, sample in data.items()}
) # type: ignore[call-arg]
except TypeError:

@ -674,6 +674,7 @@ class _BaseDataLoaderIter:
# Set pin memory device based on the current accelerator.
self._pin_memory_device = (
# pyrefly: ignore # unbound-name
acc.type
if self._pin_memory
and (acc := torch.accelerator.current_accelerator()) is not None

@ -265,6 +265,7 @@ class _DataPipeType:
# Default type for DataPipe without annotation
_T_co = TypeVar("_T_co", covariant=True)
# pyrefly: ignore # invalid-annotation
_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])

@ -283,6 +284,7 @@ class _DataPipeMeta(GenericMeta):
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
# TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
# pyrefly: ignore # no-access
cls.__origin__ = None
if "type" in namespace:
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
@ -80,6 +80,7 @@ class Capture:
def _ops_str(self):
res = ""
# pyrefly: ignore # not-iterable
for op in self.ctx["operations"]:
if len(res) > 0:
res += "\n"

@ -89,6 +90,7 @@ class Capture:
def __getstate__(self):
# TODO(VitalyFedyunin): Currently can't pickle (why?)
self.ctx["schema_df"] = None
# pyrefly: ignore # not-iterable
for var in self.ctx["variables"]:
var.calculated_value = None
state = {}

@ -112,11 +114,13 @@ class Capture:
return CaptureGetItem(self, key, ctx=self.ctx)
def __setitem__(self, key, value):
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))
def __add__(self, add_val):
res = CaptureAdd(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)

@ -125,6 +129,7 @@ class Capture:
def __sub__(self, add_val):
res = CaptureSub(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)

@ -134,15 +139,19 @@ class Capture:
res = CaptureMul(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(t)
return var
def _is_context_empty(self):
# pyrefly: ignore # bad-argument-type
return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0
def apply_ops_2(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
for op in self.ctx["operations"]:
op.execute()

@ -175,6 +184,7 @@ class Capture:
res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
var = CaptureVariable(None, ctx=self.ctx)
t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
# pyrefly: ignore # missing-attribute
self.ctx["operations"].append(t)
return var

@ -273,7 +283,9 @@ class CaptureVariable(Capture):
def apply_ops(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
for op in self.ctx["operations"]:
op.execute()
return self.calculated_value

@ -373,6 +385,7 @@ def get_val(capture):
class CaptureInitial(CaptureVariable):
def __init__(self, schema_df=None):
# pyrefly: ignore # bad-assignment
new_ctx: dict[str, list[Any]] = {
"operations": [],
"variables": [],

@ -388,6 +401,7 @@ class CaptureDataFrame(CaptureInitial):
class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
def as_datapipe(self):
# pyrefly: ignore # unsupported-operation
return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)
def raw_iterator(self):
@ -92,6 +92,7 @@ class FilterDataFramesPipe(DFIterDataPipe):
size = None
all_buffer = []
filter_res = []
# pyrefly: ignore # bad-assignment
for df in self.source_datapipe:
if size is None:
size = len(df.index)

@ -135,6 +135,7 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
_fast_forward_iterator: Optional[Iterator] = None
def __iter__(self) -> Iterator[_T_co]:
# pyrefly: ignore # bad-return
return self
def __getattr__(self, attribute_name):

@ -379,6 +380,7 @@ class _DataPipeSerializationWrapper:
value = pickle.dumps(self._datapipe)
except Exception:
if HAS_DILL:
# pyrefly: ignore # missing-attribute
value = dill.dumps(self._datapipe)
use_dill = True
else:

@ -388,6 +390,7 @@ class _DataPipeSerializationWrapper:
def __setstate__(self, state):
value, use_dill = state
if use_dill:
# pyrefly: ignore # missing-attribute
self._datapipe = dill.loads(value)
else:
self._datapipe = pickle.loads(value)

@ -404,6 +407,7 @@ class _DataPipeSerializationWrapper:
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
def __init__(self, datapipe: IterDataPipe[_T_co]):
super().__init__(datapipe)
# pyrefly: ignore # invalid-type-var
self._datapipe_iter: Optional[Iterator[_T_co]] = None
def __iter__(self) -> "_IterDataPipeSerializationWrapper":

@ -118,6 +118,7 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
for idx in sorted(self.input_col[1:], reverse=True):
del data[idx]
else:
# pyrefly: ignore # unsupported-operation
data[self.input_col] = res
else:
if self.output_col == -1:

@ -42,6 +42,7 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
"Sampler class requires input datapipe implemented `__len__`"
)
super().__init__()
# pyrefly: ignore # bad-assignment
self.datapipe = datapipe
self.sampler_args = () if sampler_args is None else sampler_args
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
Some files were not shown because too many files have changed in this diff.
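For reference, every suppression added in this diff follows the same inline pattern: a `# pyrefly: ignore # <error-code>` comment placed on the line directly above the statement being flagged, with the error code naming the specific pyrefly diagnostic. A minimal, hypothetical sketch of that placement (the function below is illustrative and not part of this diff; only the comment syntax mirrors the hunks above):

# Illustrative only: shows where the suppression comment sits relative to the
# code it covers. The token after the second '#' is the pyrefly error code
# being silenced for the following line.
import torch

def double_in_place(t: torch.Tensor) -> torch.Tensor:
    # pyrefly: ignore # missing-attribute
    return t.mul_(2)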