Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Pyrefly suppressions 7/n (#164913)
Adds suppressions so that pyrefly will typecheck clean (see https://github.com/pytorch/pytorch/issues/163283). Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions

before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199
after: INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913
Approved by: https://github.com/oulgen
Committed by: PyTorch MergeBot
Parent: 12d2ef557f
Commit: c855f8632e
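For readers unfamiliar with the suppression mechanism referenced in the test plan: each suppression is an inline comment of the form "# pyrefly: ignore # <error-code>" placed on the line directly above the flagged expression, as the diff below shows repeatedly (for example missing-attribute, bad-override, unbound-name). A minimal sketch of the pattern follows; the function and attribute names are hypothetical illustrations, not taken from this commit.

import torch

def mark_example(t: torch.Tensor) -> torch.Tensor:
    # Tensors accept ad-hoc attributes at runtime, but a static checker
    # cannot see them, so accesses like this would be reported as
    # missing-attribute unless suppressed on the preceding line.
    # pyrefly: ignore # missing-attribute
    t._example_flag = True  # hypothetical attribute, for illustration only
    return t

x = mark_example(torch.ones(2))
# pyrefly: ignore # missing-attribute
assert x._example_flag is True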
@@ -22,14 +22,16 @@ project-excludes = [
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/**",
"torch/distributed/**",
"torch/nn/**",
"torch/_dynamo/**",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",
"torch/package/_package_pickler.py",
"torch/jit/annotations.py",
"torch/utils/data/datapipes/_typing.py",
"torch/nn/functional.py",
"torch/_export/utils.py",
"torch/fx/experimental/unification/multipledispatch/__init__.py",
"torch/nn/modules/__init__.py",
# ====
"benchmarks/instruction_counts/main.py",
"benchmarks/instruction_counts/definitions/setup.py",

@@ -59,7 +59,6 @@ class TestBundledInputs(TestCase):
# despite having nominally large bundled inputs.
augmented_size = model_size(sm)
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 12))
loaded = save_and_load(sm)

@@ -67,15 +66,12 @@ class TestBundledInputs(TestCase):
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
self.assertEqual(len(inflated), len(samples))
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
for idx, inp in enumerate(inflated):
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp, tuple)
self.assertEqual(len(inp), 1)
# pyrefly: ignore # missing-attribute
self.assertIsInstance(inp[0], torch.Tensor)
if idx != 5:
# Strides might be important for benchmarking.

@@ -144,7 +140,6 @@ class TestBundledInputs(TestCase):
inflated = loaded.get_all_bundled_inputs()
self.assertEqual(inflated, samples)
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) == "first 1")
def test_multiple_methods_with_inputs(self):

@@ -192,7 +187,6 @@ class TestBundledInputs(TestCase):
# Check running and size helpers
# pyrefly: ignore # missing-attribute
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

@@ -426,7 +420,6 @@ class TestBundledInputs(TestCase):
augmented_size = model_size(sm)
# assert the size has not increased more than 8KB
# pyrefly: ignore # missing-attribute
self.assertLess(augmented_size, original_size + (1 << 13))
loaded = save_and_load(sm)

@@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
# pyrefly: ignore # missing-attribute
self.assertTrue(torch.all(x))
@dtypes(*complex_types())

@@ -57,7 +57,7 @@ class TestComplexTensor(TestCase):
x = torch.tensor(
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
)
# pyrefly: ignore # missing-attribute
self.assertFalse(torch.any(x))
@onlyCPU

@@ -142,7 +142,6 @@ class TestTypeHints(TestCase):
]
)
if result != 0:
# pyrefly: ignore # missing-attribute
self.fail(f"mypy failed:\n{stderr}\n{stdout}")

@@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.float64.to_complex(), torch.complex128)

@@ -136,7 +136,7 @@ class TestDTypeInfo(TestCase):
# Regression test for https://github.com/pytorch/pytorch/issues/124868
# If reference count is leaked this would be a set of 10 elements
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
# pyrefly: ignore # missing-attribute
self.assertLess(len(ref_cnt), 3)
self.assertEqual(torch.complex128.to_real(), torch.double)

@@ -53,6 +53,8 @@ from .eval_frame import (
OptimizedModule,
reset_code,
)
# pyrefly: ignore # deprecated
from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .pgo import reset_code_state

@@ -95,6 +95,7 @@ class ModIndex(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
# pyrefly: ignore # bad-override
def forward(x: Tensor, indices: list[Tensor]) -> Tensor:
return torch.ops.aten.index(x, indices)

@@ -242,6 +243,7 @@ def _trace_wrapped_functionalized(ctx: Any, *args: Any, **kwargs: Any) -> Any:
def autograd_function_backward_rewritten(original_backward: Any) -> Any:
def new_backward(ctx: Any, *grads: Any) -> Any:
# pyrefly: ignore # bad-assignment
grads = [g.contiguous() for g in grads]
return original_backward(ctx, *grads)

@@ -89,6 +89,7 @@ class AOTCompiledFunction:
**import_sources,
self._artifacts.backend_id: self._artifacts.compiled_fn,
}
# pyrefly: ignore # read-only
self.fn = types.FunctionType(
self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
)

@@ -206,6 +206,7 @@ def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any])
assert manager is not None
def fn(inputs: list[Any]) -> Any:
# pyrefly: ignore # missing-attribute
manager.set_to_running_backward()
return aot_model(inputs)

@@ -77,16 +77,19 @@ def tvm(
opt_level = options.get("opt_level", 3)
if scheduler == "auto_scheduler":
# pyrefly: ignore # import-error
from tvm import auto_scheduler
log_file = tempfile.NamedTemporaryFile()
# pyrefly: ignore # bad-argument-type
if not os.path.exists(log_file):
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], params, target
)
if len(tasks) != 0:
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
# pyrefly: ignore # bad-argument-type
if not os.path.exists(log_file):
assert trials > 0
tune_option = auto_scheduler.TuningOptions(

@@ -97,7 +100,9 @@ def tvm(
try:
tuner.tune(tune_option)
except Exception:
# pyrefly: ignore # bad-argument-type
if os.path.exists(log_file):
# pyrefly: ignore # bad-argument-type
os.unlink(log_file)
raise

@@ -107,6 +112,7 @@ def tvm(
):
lib = relay.build(mod, target=target, params=params)
elif scheduler == "meta_schedule":
# pyrefly: ignore # import-error
from tvm import meta_schedule as ms
with tempfile.TemporaryDirectory() as work_dir:
@@ -37,6 +37,7 @@ if sys.version_info >= (3, 11):
TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
else:
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
# pyrefly: ignore # unsupported-operation
if (3, 12) <= sys.version_info < (3, 14):
TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
if sys.version_info >= (3, 13):

@@ -903,6 +903,7 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None:
inst.arg = abs(
int(target.offset - inst.offset - instruction_size(inst))
)
# pyrefly: ignore # unsupported-operation
inst.arg //= 2
inst.argval = target.offset
inst.argrepr = f"to {target.offset}"

@@ -1354,6 +1355,7 @@ def update_offsets(instructions: Sequence[Instruction]) -> None:
offset = 0
for inst in instructions:
inst.offset = offset
# pyrefly: ignore # unsupported-operation
offset += instruction_size(inst)

@@ -464,6 +464,7 @@ def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
try:
prof.enable()
start_ts = time.time()
# pyrefly: ignore # bad-argument-type
retval = prof.runcall(func, *args, **kwargs)
profile_latency = time.time() - start_ts
prof.disable()

@@ -957,6 +958,7 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
if isinstance(mod, torch.nn.Module):
mod = mod.forward
if hasattr(mod, "__self__"):
# pyrefly: ignore # missing-attribute
return mod.__func__, mod.__self__
elif inspect.isfunction(mod):
return mod, None

@@ -1096,6 +1098,7 @@ def _fullgraph_capture_frame(
while cur_exn.__cause__ is not None:
cur_exn.__cause__.with_traceback(None)
cur_exn = cur_exn.__cause__
# pyrefly: ignore # invalid-inheritance
raise e.with_traceback(None) from e.__cause__  # User compiler error
return CaptureOutput(

@@ -1119,6 +1122,7 @@ def compile_frame(  # type: ignore[return]
frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
distributed_state: Optional[DistributedState] = None,
package: Optional[CompilePackage] = None,
# pyrefly: ignore # bad-return
) -> DynamoOutput:
"""
A helper function taking a frame and backend, then return the generated bytecode

@@ -20,6 +20,7 @@ allowed to compute gradients on).
class TracableCreateParameter(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter:
assert not tensor.requires_grad
return placeholder.set_(tensor)

@@ -879,6 +879,7 @@ def aot_graph_input_parser(
data_type, shape_str = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
# pyrefly: ignore # bad-argument-type
kwargs[param] = gen_tensor(shape, dtype)
match = re.search(sym_shape_regex, annotation)

@@ -892,6 +893,7 @@ def aot_graph_input_parser(
attr_name, data_type, shape_str, _ = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
# pyrefly: ignore # bad-argument-type
setattr(container, attr_name, gen_tensor(shape, dtype))
return kwargs

@@ -95,6 +95,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True):  # type: ig
nonrecursive_disable_wrapper._torchdynamo_disable = True  # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason  # type: ignore[attr-defined]
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
return nonrecursive_disable_wrapper
if fn is None:

@@ -306,6 +307,7 @@ def forbid_in_graph(fn: Any) -> Any:
if isinstance(fn, (list, tuple)):
return [forbid_in_graph(x) for x in fn]
assert callable(fn), "forbid_in_graph applies only to callables"
# pyrefly: ignore # missing-attribute
fn._dynamo_forbidden = True
return fn

@@ -653,21 +655,28 @@ def mark_dynamic(
if isinstance(index, int):
if not hasattr(t, "_dynamo_dynamic_indices"):
# pyrefly: ignore # missing-attribute
t._dynamo_dynamic_indices = set()
# pyrefly: ignore # missing-attribute
t._dynamo_dynamic_range = set()
# pyrefly: ignore # missing-attribute
t._dynamo_hint_overrides = {}
if not hasattr(t, "_specialize_on"):
# pyrefly: ignore # missing-attribute
t._specialize_on = {}
if hint_override:
# pyrefly: ignore # missing-attribute
t._dynamo_hint_overrides[index] = hint_override
# TODO(voz): Should we bounds check?
# pyrefly: ignore # missing-attribute
t._dynamo_dynamic_indices.add(index)
t._dynamo_dynamic_range.add(_DimRange(index, min, max))  # type: ignore[arg-type]
# FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
# TypeError: 'Attribute' object does not support item assignment
# pyrefly: ignore # missing-attribute
if isinstance(t._specialize_on, dict):
t._specialize_on[index] = specialize_on if specialize_on is not None else []

@@ -692,8 +701,10 @@ def maybe_mark_dynamic(t: Any, index: Union[int, list[Any], tuple[Any]]) -> None
if isinstance(index, int):
if not hasattr(t, "_dynamo_weak_dynamic_indices"):
# pyrefly: ignore # missing-attribute
t._dynamo_weak_dynamic_indices = set()
# TODO(voz): Should we bounds check?
# pyrefly: ignore # missing-attribute
t._dynamo_weak_dynamic_indices.add(index)
return

@@ -745,8 +756,11 @@ def mark_static(
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)
# pyrefly: ignore # bad-argument-type
if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module):
# pyrefly: ignore # missing-attribute
t._dynamo_marked_static = True
# pyrefly: ignore # bad-return
return t
if not isinstance(t, torch.Tensor):

@@ -205,6 +205,7 @@ class CudaInterface(DeviceInterface):
Event = torch.cuda.Event  # type: ignore[assignment]
Stream = torch.cuda.Stream  # type: ignore[assignment]
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def set_device(device: int) -> None:

@@ -240,6 +241,7 @@ class CudaInterface(DeviceInterface):
set_device = staticmethod(torch.cuda.set_device)
device_count = staticmethod(torch.cuda.device_count)
stream = staticmethod(torch.cuda.stream)  # type: ignore[assignment]
# pyrefly: ignore # bad-override
current_stream = staticmethod(torch.cuda.current_stream)
set_stream = staticmethod(torch.cuda.set_stream)  # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id)  # type: ignore[assignment]

@@ -300,6 +302,7 @@ class MtiaInterface(DeviceInterface):
Event = torch.mtia.Event  # type: ignore[assignment]
Stream = torch.mtia.Stream  # type: ignore[assignment]
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def set_device(device: int) -> None:

@@ -335,6 +338,7 @@ class MtiaInterface(DeviceInterface):
set_device = staticmethod(torch.mtia.set_device)  # type: ignore[assignment]
device_count = staticmethod(torch.mtia.device_count)
stream = staticmethod(torch.mtia.stream)  # type: ignore[assignment]
# pyrefly: ignore # bad-override
current_stream = staticmethod(torch.mtia.current_stream)
set_stream = staticmethod(torch.mtia.set_stream)  # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id)  # type: ignore[assignment]

@@ -381,6 +385,7 @@ class XpuInterface(DeviceInterface):
Event = torch.xpu.Event  # type: ignore[assignment]
Stream = torch.xpu.Stream  # type: ignore[assignment]
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def set_device(device: int) -> None:

@@ -416,6 +421,7 @@ class XpuInterface(DeviceInterface):
set_device = staticmethod(torch.xpu.set_device)
device_count = staticmethod(torch.xpu.device_count)  # type: ignore[has-type]
stream = staticmethod(torch.xpu.stream)  # type: ignore[assignment]
# pyrefly: ignore # bad-override
current_stream = staticmethod(torch.xpu.current_stream)
set_stream = staticmethod(torch.xpu.set_stream)  # type: ignore[assignment]
_set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id)  # type: ignore[assignment]

@@ -458,6 +464,7 @@ class CpuDeviceProperties:
class CpuInterface(DeviceInterface):
# pyrefly: ignore # bad-override
class Event(torch.Event):
def __init__(self, enable_timing: bool = True) -> None:
self.time = 0.0

@@ -468,6 +475,7 @@ class CpuInterface(DeviceInterface):
def record(self, stream: Any = None) -> None:
self.time = time.perf_counter()
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def get_device_properties(

@@ -543,6 +551,7 @@ class MpsInterface(DeviceInterface):
def synchronize(device: torch.types.Device = None) -> None:
torch.mps.synchronize()
# pyrefly: ignore # bad-override
class Worker:
@staticmethod
def get_device_properties(device: torch.types.Device = None) -> Any:
@@ -484,6 +484,7 @@ class OptimizedModule(torch.nn.Module):
self._initialize()
@property
# pyrefly: ignore # bad-override
def training(self) -> bool:
return self._orig_mod.training

@@ -892,6 +893,7 @@ class _TorchDynamoContext:
while cur_exn.__cause__ is not None:
cur_exn.__cause__.with_traceback(None)
cur_exn = cur_exn.__cause__
# pyrefly: ignore # invalid-inheritance
raise e.with_traceback(None) from e.__cause__  # User compiler error
except ShortenTraceback as e:
# Failures in the backend likely don't have useful

@@ -1020,7 +1022,10 @@ class OptimizeContext(_TorchDynamoContext):
assert rebuild_ctx is not None
compiler_fn = rebuild_ctx()
ctx = torch._dynamo.compiled_autograd._enable(
compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
compiler_fn,
# pyrefly: ignore # bad-argument-type
dynamic=_dynamic,
ignore_active_disable_ctx=False,
)
ctx.__enter__()
return functools.partial(ctx.__exit__, None, None, None)

@@ -1083,6 +1088,7 @@ class DisableContext(_TorchDynamoContext):
cls_obj.__call__ = self(cls_obj.__call__)
if issubclass(cls_obj, torch.nn.Module):
# NN module variable tracker directly inlines the _call_impl. Disable it.
# pyrefly: ignore # missing-attribute
cls_obj._call_impl = self(cls_obj._call_impl)
return cls_obj

@@ -1988,6 +1994,7 @@ def export(
path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any]
) -> Any:
if isinstance(t, torch.Tensor):
# pyrefly: ignore # missing-attribute
return ambient_fake_mode.from_tensor(t, static_shapes=True)
elif isinstance(t, _IntWrapper):
if (

@@ -2068,8 +2075,11 @@ def export(
)
and not trace_rules.check(call_to_inspect)
):
# pyrefly: ignore # unbound-name
dim_constraints.solve()
# pyrefly: ignore # unbound-name
forced_specializations = dim_constraints.forced_specializations()
# pyrefly: ignore # unbound-name
msg = dim_constraints.prettify_results(
original_signature,
dynamic_shapes,

@@ -2090,9 +2100,11 @@ def export(
)
# Error if we have any constraints on static values
# pyrefly: ignore # unbound-name
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
# pyrefly: ignore # unbound-name
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "

@@ -369,6 +369,7 @@ def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedExc
observed_exception_map[exc_type] = type(  # type: ignore[assignment]
f"Observed{name}Error", (ObservedException,), {}
)
# pyrefly: ignore # index-error
return observed_exception_map[exc_type]

@@ -96,7 +96,9 @@ def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
args, kwargs = pytree.tree_map_only(
torch.Tensor, lambda x: x.numpy(), (args, kwargs)
)
# pyrefly: ignore # invalid-param-spec
out = f(*args, **kwargs)
# pyrefly: ignore # missing-attribute
return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)
return wrap

@@ -250,6 +250,7 @@ class DynamoGraphTransformer(torch.fx.Transformer):
else:
placeholder.node.meta["val"] = self.flat_inputs[i]
# pyrefly: ignore # unsupported-operation
self.new_input_nodes[i] = placeholder
def _create_placeholder_mapping(self) -> None:

@@ -324,12 +325,18 @@ class DynamoGraphTransformer(torch.fx.Transformer):
# Copy module metadata like the original implementation
if hasattr(self.module, "meta"):
# pyrefly: ignore # unsupported-operation
if "dynamo_flat_name_to_original_fqn" in self.module.meta:
# pyrefly: ignore # index-error
result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
# pyrefly: ignore # index-error
"dynamo_flat_name_to_original_fqn"
]
# pyrefly: ignore # unsupported-operation
if "dynamo_compile_id" in self.module.meta:
# pyrefly: ignore # index-error
result_gm.meta["dynamo_compile_id"] = self.module.meta[
# pyrefly: ignore # index-error
"dynamo_compile_id"
]

@@ -361,8 +368,11 @@ def _suggest_or_raise_constraint_violation(
torch._ops.OpOverloadPacket | torch._ops.OpOverload,
)
):
# pyrefly: ignore # unbound-name
dim_constraints.solve()
# pyrefly: ignore # unbound-name
forced_specializations = dim_constraints.forced_specializations()
# pyrefly: ignore # unbound-name
msg = dim_constraints.prettify_results(
inspect.signature(orig_callable),  # type: ignore[attr-defined]
dynamic_shapes,

@@ -383,9 +393,11 @@ def _suggest_or_raise_constraint_violation(
)
# Error if we have any constraints on static values
# pyrefly: ignore # unbound-name
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
# pyrefly: ignore # unbound-name
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "

@@ -320,6 +320,7 @@ class GraphRegionTracker:
if len(group) > 1:
region_group = []
min_rank = math.inf
# pyrefly: ignore # bad-assignment
for node in group:
# some nodes aren't in the topo ranking?
if node in topological_ranking:

@@ -640,6 +640,7 @@ class GuardManagerWrapper:
if isinstance(guard, RelationalGuard):
if guard not in self.printed_relational_guards:
self.printed_relational_guards.add(guard)
# pyrefly: ignore # bad-argument-type
body.writelines(self.get_guard_lines(guard))
else:
body.writelines(

@@ -700,6 +701,7 @@ class GuardManagerWrapper:
for guard in mgr.get_leaf_guards():
if isinstance(guard, RelationalGuard):
if guard not in relational_guards_seen:
# pyrefly: ignore # bad-argument-type
self.code_parts.extend(get_code_parts(guard))
relational_guards_seen.add(guard)
else:

@@ -716,6 +718,7 @@ def from_numpy(a: Any) -> torch.Tensor:
# Re-enable torch function since we disable it on leaf guards
# we need it to properly construct the tensor if a default device is set
with torch.overrides._enable_torch_function():
# pyrefly: ignore # missing-attribute
return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a

@@ -729,6 +732,7 @@ def uninteresting_files() -> set[str]:
from torch._dynamo.polyfills.loader import POLYFILLED_MODULES
# pyrefly: ignore # bad-argument-type
mods.extend(POLYFILLED_MODULES)
return {inspect.getfile(m) for m in mods}

@@ -2205,6 +2209,7 @@ class GuardBuilder(GuardBuilderBase):
return
# Python math library doesn't support complex nan, so we need to use numpy
# pyrefly: ignore # missing-attribute
if istype(val, complex) and np.isnan(val):
code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
self._set_guard_export_info(guard, code)

@@ -2495,6 +2500,7 @@ class GuardBuilder(GuardBuilderBase):
# sources for the corresponding tensor dimension.
return [
TensorPropertySource(source, TensorProperty.SIZE, dim)
# pyrefly: ignore # missing-attribute
for source in output_graph.tracked_fakes_id_to_source[t_id]
]

@@ -2531,6 +2537,7 @@ class GuardBuilder(GuardBuilderBase):
equalities_inputs = None
def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]:
# pyrefly: ignore # missing-attribute
return output_graph.shape_env.produce_guards_verbose(
[a.fake for a in fs],  # type: ignore[misc]
[a.source for a in fs],

@@ -2538,6 +2545,7 @@ class GuardBuilder(GuardBuilderBase):
equalities_inputs=equalities_inputs,
source_ref=self.source_ref,
# Export keeps static.
# pyrefly: ignore # missing-attribute
ignore_static=(not output_graph.export),
langs=langs,
)

@@ -2599,7 +2607,9 @@ class GuardBuilder(GuardBuilderBase):
if not python_fallback:
assert cpp_code_parts  # type: ignore[possibly-undefined]
code_parts, source_to_symbol = (
# pyrefly: ignore # unbound-name
cpp_code_parts.exprs,
# pyrefly: ignore # unbound-name, missing-attribute
cpp_code_parts.source_to_symbol,
)

@@ -2630,7 +2640,9 @@ class GuardBuilder(GuardBuilderBase):
assert cpp_code_parts  # type: ignore[possibly-undefined]
code_parts, source_to_symbol = (
# pyrefly: ignore # unbound-name
cpp_code_parts.exprs,
# pyrefly: ignore # unbound-name, missing-attribute
cpp_code_parts.source_to_symbol,
)

@@ -3240,6 +3252,7 @@ class GuardsStatePickler(pickle.Pickler):
assert _.__closure__ is not None
return _.__closure__[0]
# pyrefly: ignore # bad-override
def reducer_override(
self, obj: Any
) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:

@@ -4072,10 +4085,13 @@ class CheckFunctionManager:
and (cache_entry := self.guard_manager.cache_entry) is not None
and (extra_state := self.guard_manager.extra_state) is not None
):
# pyrefly: ignore # unbound-name
assert isinstance(cache_entry, CacheEntry)
# pyrefly: ignore # unbound-name
assert isinstance(extra_state, ExtraState)
reason = f"Cache line invalidated because {obj_str} got deallocated"
deleted_guard_manager = DeletedGuardManagerWrapper(reason)
# pyrefly: ignore # unbound-name
extra_state.invalidate(cache_entry, deleted_guard_manager)
self.guard_manager = deleted_guard_manager
@@ -707,6 +707,7 @@ class OutputGraph(OutputGraphCommon):
self.backward_state_proxy: Optional[torch.fx.Proxy] = None
self.backward_state_var: Optional[str] = None
# pyrefly: ignore # bad-override
self.name_of_builtins_dict_key_in_fglobals: str = (
self.install_builtins_dict_in_fglobals()
)

@@ -1146,6 +1147,7 @@ class OutputGraph(OutputGraphCommon):
vt = self.root_tx.output.side_effects.track_object_existing(target, vt)
assert "tensor_dict" not in vt.as_proxy().node.meta
# pyrefly: ignore # bad-argument-type
vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target)
return vt

@@ -1157,6 +1159,7 @@ class OutputGraph(OutputGraphCommon):
install_guard(source.make_guard(GuardBuilder.NN_MODULE))
def wrap_name(module_key: str) -> VariableTracker:
# pyrefly: ignore # bad-argument-type
return NNModuleVariable(type(target), module_key, target, **options)
else:

@@ -1970,7 +1973,9 @@ class OutputGraph(OutputGraphCommon):
tx = self.root_tx
assert tx is not None
if (ds := tx.distributed_state) is not None and ds.all_states is None:
# pyrefly: ignore # unbound-name
compile_pg = ds.compile_pg
# pyrefly: ignore # unbound-name
log.info("compiler_collective %s", ds.local_state)
torch._logging.trace_structured(
"artifact",

@@ -1978,6 +1983,7 @@ class OutputGraph(OutputGraphCommon):
"name": "compiler_collective",
"encoding": "string",
},
# pyrefly: ignore # unbound-name
payload_fn=lambda: ds.local_state.render(),
)
device_types = compile_pg._device_types

@@ -1991,7 +1997,9 @@ class OutputGraph(OutputGraphCommon):
dynamo_timed("compiler_collective", log_pt2_compile_event=True),
):
all_states: list[Any] = [None] * compile_pg.size()
# pyrefly: ignore # unbound-name
dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
# pyrefly: ignore # unbound-name
ds.all_states = all_states
# Clear speculation log, because are tracing may diverge due to
# this information from the compiler collective

@@ -2321,6 +2329,7 @@ class OutputGraph(OutputGraphCommon):
},
)
# pyrefly: ignore # unbound-name
return compiled_fn
def dedup_pass(self) -> dict[str, torch.fx.GraphModule]:

@@ -2375,6 +2384,7 @@ class OutputGraph(OutputGraphCommon):
isinstance(b, torch.SymBool)
and (r := b.node.maybe_as_bool()) is not None
):
# pyrefly: ignore # unbound-name
return r
# TODO: We can also technically remove all cases when the input
# doesn't have unbacked inputs, since it's all in the ShapeEnv

@@ -2740,6 +2750,7 @@ def check_pt2_compliant_op(
hints=[],
)
# pyrefly: ignore # unbound-name
op = getattr(target, overload)
if torch.Tag.pt2_compliant_tag in op.tags:
encountered_compliant_op(op)

@@ -2747,6 +2758,7 @@ def check_pt2_compliant_op(
encountered_non_compliant_op(
op,
f"Encountered the torch.ops.OpOverloadPacket {target} "
# pyrefly: ignore # unbound-name
f"which resolves to the overload ({overload}) that is "
f"not PT2 compliant.",
)

@@ -2767,6 +2779,7 @@ class LazyProxy:
**kwargs: P.kwargs,
) -> None:
self.tracer = tracer
# pyrefly: ignore # invalid-type-var
self.fn = fn
self.args = args
self.kwargs = kwargs

@@ -319,6 +319,7 @@ def _get_code_source(code: types.CodeType) -> tuple[str, str]:
code_source = _find_code_source(toplevel)
if code_source is None:
_raise_resolution_error(code, toplevel)
# pyrefly: ignore # missing-attribute
return toplevel.__qualname__, code_source.strip(".")

@@ -593,9 +594,11 @@ class CompilePackage:
f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
)
# pyrefly: ignore # bad-assignment
self._source_info = dynamo.source_info
main, *codes = dynamo.codes
# pyrefly: ignore # bad-assignment
self._codes = {self._innermost_fn.__code__: main}
for code in codes:
self._codes[SerializedCode.to_code_object(code.python_code)] = code

@@ -603,6 +606,7 @@ class CompilePackage:
self._add_function(
self._innermost_fn.__code__, self._innermost_fn.__module__
)
# pyrefly: ignore # bad-assignment
self._initialized = True
def _add_function(

@@ -746,6 +750,7 @@ class CompilePackage:
for name in names:
module.__dict__.pop(name)
# pyrefly: ignore # bad-assignment
self._installed_globals = {}
_reset_precompile_entries(self._innermost_fn.__code__)

@@ -167,6 +167,7 @@ class CodeId:
@dataclasses.dataclass
class CodeState:
automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
# pyrefly: ignore # unbound-name
default_factory=lambda: defaultdict(FrameStateSizeEntry)
)

@@ -851,6 +852,7 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
not _CODE_STATE
and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
):
# pyrefly: ignore # unbound-name
extra_read_key = get_extra_cache_key(sticky_read)
if extra_read_key is not None:
get_extra_remote_code_state(extra_read_key)

@@ -196,6 +196,7 @@ def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T1],
/,

@@ -205,6 +206,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T1],
iter2: Iterable[_T2],

@@ -213,6 +215,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T1],
iter2: Iterable[_T2],

@@ -223,6 +226,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T],
iter2: Iterable[_T],

@@ -233,6 +237,7 @@ def zip_longest(
@overload
# pyrefly: ignore # inconsistent-overload
def zip_longest(
iter1: Iterable[_T],
iter2: Iterable[_T],

@@ -30,10 +30,12 @@ _Us = TypeVarTuple("_Us")
@overload
# pyrefly: ignore # inconsistent-overload
def attrgetter(attr: str, /) -> Callable[[Any], _U]: ...
@overload
# pyrefly: ignore # inconsistent-overload
def attrgetter(
attr1: str, attr2: str, /, *attrs: str
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...

@@ -68,10 +70,12 @@ def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]:
@overload
# pyrefly: ignore # inconsistent-overload
def itemgetter(item: _T, /) -> Callable[[Any], _U]: ...
@overload
# pyrefly: ignore # inconsistent-overload
def itemgetter(
item1: _T1, item2: _T2, /, *items: Unpack[_Ts]
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...

@@ -17,6 +17,7 @@ __all__ = ["fspath"]
@substitute_in_graph(os.fspath, can_constant_fold_through=True)
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
if isinstance(path, (str, bytes)):
# pyrefly: ignore # bad-return
return path
path_type = type(path)

@@ -171,6 +171,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type)
):
# pyrefly: ignore # bad-return
return treespec._unflatten_func(
treespec._metadata,
children_representations,

@@ -49,8 +49,11 @@ class ProfileMetrics:
if isinstance(other, int):
other = ProfileMetrics(other, other, other)
return ProfileMetrics(
# pyrefly: ignore # no-matching-overload
self.microseconds / max(1, other.microseconds),
# pyrefly: ignore # bad-argument-type
self.operators / max(1, other.operators),
# pyrefly: ignore # bad-argument-type
self.fusions / max(1, other.fusions),
)
@ -370,13 +370,16 @@ isolate_fails_code_str = None
|
||||
|
||||
try:
|
||||
if isinstance(kernel, Autotuner):
|
||||
# pyrefly: ignore # missing-attribute
|
||||
if isinstance(kernel.fn, Heuristics):
|
||||
model_str += "ERROR: Repro will not work as intended, "
|
||||
model_str += "triton.runtime.autotuner.Heuristics is not currently supported\n"
|
||||
break
|
||||
|
||||
config_strs = []
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for kernel_config in kernel.configs:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
config_strs.append(f"""triton.Config(
|
||||
{str(kernel_config.kwargs)},
|
||||
num_warps={kernel_config.num_warps},
|
||||
@ -394,8 +397,10 @@ isolate_fails_code_str = None
|
||||
""").strip()
|
||||
|
||||
model_str += "\n@triton.jit\n"
|
||||
# pyrefly: ignore # missing-attribute
|
||||
src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src
|
||||
fn_name = (
|
||||
# pyrefly: ignore # missing-attribute
|
||||
kernel._fn_name
|
||||
if isinstance(kernel, JITFunction)
|
||||
else kernel.fn._fn_name
|
||||
@ -409,7 +414,9 @@ isolate_fails_code_str = None
|
||||
model_str += "ERROR: Repro will not work as intended, "
|
||||
model_str += f"User defined triton kernel exception: {e}\n"
|
||||
|
||||
# pyrefly: ignore # unbound-name
|
||||
if len(kernel_side_table.constant_args) > 0:
|
||||
# pyrefly: ignore # unbound-name
|
||||
model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n"
|
||||
|
||||
model_str += NNModuleToString.convert(gm)
|
||||
@ -420,8 +427,10 @@ isolate_fails_code_str = None
|
||||
# Extract from graph placeholders and their corresponding arguments
|
||||
placeholder_targets = fx_placeholder_targets(gm)
|
||||
for placeholder, arg in zip(placeholder_targets, args):
|
||||
# pyrefly: ignore # unbound-name
|
||||
if isinstance(arg, (int, torch.SymInt)):
|
||||
writer.symint(placeholder, arg)
|
||||
# pyrefly: ignore # unbound-name
|
||||
elif isinstance(arg, torch.Tensor):
|
||||
# TODO: improve these names with FQN
|
||||
writer.tensor(placeholder, arg)
|
||||
@ -431,16 +440,20 @@ isolate_fails_code_str = None
|
||||
writer.unsupported(placeholder, arg)
|
||||
|
||||
# Extract symbolic variables from the same arguments
|
||||
# pyrefly: ignore # unbound-name
|
||||
if isinstance(arg, torch.SymInt):
|
||||
sym_name = str(arg.node)
|
||||
if arg.node.hint is not None:
|
||||
used_syms[sym_name] = arg.node.hint
|
||||
# pyrefly: ignore # unbound-name
|
||||
elif isinstance(arg, torch.Tensor):
|
||||
# Extract symbolic variables from tensor shapes and strides
|
||||
for dim in arg.shape:
|
||||
# pyrefly: ignore # unbound-name
|
||||
if isinstance(dim, torch.SymInt) and dim.node.hint is not None:
|
||||
used_syms[str(dim.node)] = dim.node.hint
|
||||
for stride in arg.stride():
|
||||
# pyrefly: ignore # unbound-name
|
||||
if isinstance(stride, torch.SymInt) and stride.node.hint is not None:
|
||||
used_syms[str(stride.node)] = stride.node.hint
|
||||
|
||||
@ -758,6 +771,7 @@ def repro_common(
|
||||
# TODO: speed this up
|
||||
mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args)
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
torch._inductor.config.generate_intermediate_hooks = True
|
||||
|
||||
return mod, args
|
||||
|
@ -301,6 +301,7 @@ def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]:
|
||||
def repro_common(
|
||||
options: Any, exported_program: ExportedProgram
|
||||
) -> tuple[torch.fx.GraphModule, Any, Any]:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
torch._inductor.config.generate_intermediate_hooks = True
|
||||
mod = exported_program.module(check_guards=False)
|
||||
args, kwargs = exported_program.example_inputs
|
||||
@ -422,6 +423,7 @@ def repro_minify(
|
||||
) -> bool:
|
||||
# Need to export first so the in_spec and out_spec are populated
|
||||
tuple_inputs = tuple(flat_example_inputs)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
gm = export_for_aoti_minifier(
|
||||
gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error
|
||||
)
|
||||
|
@ -102,6 +102,7 @@ def _bytecode_from_template_with_split(
|
||||
def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
|
||||
# NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
|
||||
# on torch._dynamo.utils.
|
||||
# pyrefly: ignore # unknown-name
|
||||
global __import_torch_dot__dynamo_dot_utils
|
||||
try:
|
||||
dummy
|
||||
@ -555,6 +556,7 @@ class ContinueExecutionCache:
|
||||
|
||||
# remap original instructions' exception table entries
|
||||
if old_hook_target_remap:
|
||||
# pyrefly: ignore # unbound-name
|
||||
assert is_py311_plus
|
||||
for inst in instructions:
|
||||
if (
|
||||
|
@ -696,6 +696,7 @@ class SideEffects:
|
||||
cg.add_cache(var)
|
||||
var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
|
||||
elif var.source is None:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
var.source = LocalCellSource(var.local_name)
|
||||
elif isinstance(var, variables.TensorVariable):
|
||||
# NOTE: for historical reasons we never assigned local sources
|
||||
@ -732,6 +733,7 @@ class SideEffects:
|
||||
if isinstance(var, variables.UserDefinedObjectVariable):
|
||||
|
||||
def load_new_method() -> None:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
assert var.base_cls_vt is not None
|
||||
cg(var.base_cls_vt) # type: ignore[attr-defined]
|
||||
cg.extend_output([cg.create_load_attr("__new__")])
|
||||
@ -978,7 +980,9 @@ class SideEffects:
|
||||
|
||||
elif self.is_attribute_mutation(var):
|
||||
if isinstance(
|
||||
var, variables.UserDefinedDictVariable
|
||||
var,
|
||||
variables.UserDefinedDictVariable,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
) and self.is_modified(var._dict_vt):
|
||||
# Do dict related update manually here. The store_attr
|
||||
# mutations will be applied later.
|
||||
@ -1011,6 +1015,7 @@ class SideEffects:
|
||||
]
|
||||
)
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
cg(var._dict_vt, allow_cache=False) # Don't codegen via source
|
||||
cg.extend_output(
|
||||
[
|
||||
@ -1031,7 +1036,9 @@ class SideEffects:
|
||||
]
|
||||
)
|
||||
elif isinstance(
|
||||
var, variables.UserDefinedListVariable
|
||||
var,
|
||||
variables.UserDefinedListVariable,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
) and self.is_modified(var._list_vt):
|
||||
# Update the list to the updated items. Be careful in
|
||||
# calling the list methods and not the overridden methods.
|
||||
@ -1048,6 +1055,7 @@ class SideEffects:
|
||||
]
|
||||
)
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
cg(var._list_vt, allow_cache=False) # Don't codegen via source
|
||||
cg.extend_output(
|
||||
[
|
||||
|
@ -563,6 +563,7 @@ def log_graph_break(
|
||||
)
|
||||
else:
|
||||
user_stack = get_stack_above_dynamo() + user_stack # type: ignore[assignment]
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
user_stack = collapse_resume_frames(user_stack)
|
||||
user_stack_formatted = "".join(traceback.format_list(user_stack))
|
||||
user_stack_trace = (
|
||||
@ -1040,6 +1041,7 @@ class BytecodeDispatchTableMeta(type):
|
||||
op: getattr(cls, opname, functools.partial(_missing, opname))
|
||||
for opname, op in dis.opmap.items()
|
||||
}
|
||||
# pyrefly: ignore # missing-attribute
|
||||
cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]
|
||||
|
||||
|
||||
@ -1788,13 +1790,17 @@ class InstructionTranslatorBase(
|
||||
source = self.import_source(module_name)
|
||||
|
||||
if self.exec_recorder:
|
||||
# pyrefly: ignore # unbound-name
|
||||
self.exec_recorder.add_local_mod(recorded_name, value)
|
||||
|
||||
# pyrefly: ignore # unbound-name
|
||||
if istype(value, (types.ModuleType, DummyModule)):
|
||||
# pyrefly: ignore # unbound-name
|
||||
self.push(PythonModuleVariable(value, source=source))
|
||||
else:
|
||||
unimplemented_v2(
|
||||
gb_type="Bad import result",
|
||||
# pyrefly: ignore # unbound-name
|
||||
context=typestr(value),
|
||||
explanation="Import result is not a Python module.",
|
||||
hints=[],
|
||||
@ -1873,6 +1879,7 @@ class InstructionTranslatorBase(
|
||||
exit, exc = self.popn(2)
|
||||
assert exc is None
|
||||
self.push(exc)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {}))
|
||||
|
||||
def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None:
|
||||
@ -2294,7 +2301,9 @@ class InstructionTranslatorBase(
|
||||
):
|
||||
return True
|
||||
elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
|
||||
exc_instance.fn, expected_type.fn
|
||||
exc_instance.fn,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
expected_type.fn,
|
||||
):
|
||||
return True
|
||||
|
||||
@ -2354,26 +2363,37 @@ class InstructionTranslatorBase(
|
||||
assert isinstance(null, NullVariable)
|
||||
|
||||
if not isinstance(
|
||||
argsvars, BaseListVariable
|
||||
# pyrefly: ignore # unbound-name
|
||||
argsvars,
|
||||
BaseListVariable,
|
||||
# pyrefly: ignore # unbound-name
|
||||
) and argsvars.has_force_unpack_var_sequence(self):
|
||||
# pyrefly: ignore # unbound-name
|
||||
argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))
|
||||
|
||||
# Unpack for cases like fn(**obj) where obj is a map
|
||||
# pyrefly: ignore # unbound-name
|
||||
if isinstance(kwargsvars, UserDefinedObjectVariable):
|
||||
kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type]
|
||||
|
||||
# pyrefly: ignore # unbound-name
|
||||
if not isinstance(argsvars, BaseListVariable) or not isinstance(
|
||||
kwargsvars, ConstDictVariable
|
||||
# pyrefly: ignore # unbound-name
|
||||
kwargsvars,
|
||||
ConstDictVariable,
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Variadic function call with bad args/kwargs type",
|
||||
# pyrefly: ignore # unbound-name
|
||||
context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
|
||||
explanation="Expected args to be a list and kwargs to be a dict",
|
||||
hints=[*graph_break_hints.USER_ERROR],
|
||||
)
|
||||
|
||||
# Map to a dictionary of str -> VariableTracker
|
||||
# pyrefly: ignore # unbound-name, missing-attribute
|
||||
kwargsvars = kwargsvars.keys_as_python_constant()
|
||||
# pyrefly: ignore # unbound-name, missing-attribute
|
||||
self.call_function(fn, argsvars.items, kwargsvars)
|
||||
|
||||
@break_graph_if_unsupported(push=1)
|
||||
@ -2437,6 +2457,7 @@ class InstructionTranslatorBase(
|
||||
|
||||
def LOAD_ATTR(self, inst: Instruction) -> None:
|
||||
if sys.version_info >= (3, 12):
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
if inst.arg % 2:
|
||||
self.LOAD_METHOD(inst)
|
||||
return
|
||||
@ -3029,14 +3050,17 @@ class InstructionTranslatorBase(
|
||||
"(i.e. `a, b, c = d`).",
|
||||
hints=[*graph_break_hints.USER_ERROR],
|
||||
)
|
||||
# pyrefly: ignore # unbound-name
|
||||
if len(val) != inst.argval:
|
||||
unimplemented_v2(
|
||||
gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
|
||||
# pyrefly: ignore # unbound-name
|
||||
context=f"expected length: {inst.argval}, actual: {len(val)}",
|
||||
explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
|
||||
"(i.e. `a, b, c = d`) with unexpected length.",
|
||||
hints=[*graph_break_hints.DYNAMO_BUG],
|
||||
)
|
||||
# pyrefly: ignore # unbound-name
|
||||
for i in reversed(val):
|
||||
self.push(i)
|
||||
|
||||
@ -3409,9 +3433,13 @@ class InstructionTranslatorBase(
|
||||
args = [contents[1]]
|
||||
|
||||
if kw_names:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
args = args + contents[2 : -len(kw_names)]
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
kwargs_list = contents[-len(kw_names) :]
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
kwargs = dict(zip(kw_names, kwargs_list))
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
assert len(kwargs) == len(kw_names)
|
||||
else:
|
||||
args = args + contents[2:]
|
||||
@ -4118,6 +4146,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
and isinstance(tos, LocalGeneratorObjectVariable)
|
||||
):
|
||||
self.stack[-1] = ListIteratorVariable(
|
||||
# pyrefly: ignore # unbound-name
|
||||
tos.force_unpack_var_sequence(self),
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
@ -4188,6 +4217,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
"""Trace and inline a called method"""
|
||||
|
||||
symbolic_result: Optional[VariableTracker]
|
||||
# pyrefly: ignore # bad-override
|
||||
parent: InstructionTranslatorBase
|
||||
|
||||
@classmethod
|
||||
@ -4231,6 +4261,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
# trace through.
|
||||
if (
|
||||
hasattr(getattr(func, "fn", None), "_origin")
|
||||
# pyrefly: ignore # missing-attribute
|
||||
and func.fn._origin is produce_trampoline_autograd_apply
|
||||
):
|
||||
# Known sound
|
||||
@ -4305,12 +4336,14 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
tracing_ctx.previously_inlined_functions[code] = result
|
||||
|
||||
try:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
sub_locals = func.bind_args(parent, args, kwargs)
|
||||
except TypeError as e:
|
||||
# Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
|
||||
raise ArgsMismatchError( # noqa: B904
|
||||
"{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
|
||||
reason=str(e),
|
||||
# pyrefly: ignore # missing-attribute
|
||||
func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
|
||||
args=[arg.python_type() for arg in args],
|
||||
kwargs=kwargs,
|
||||
@ -4394,6 +4427,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
sub_locals,
|
||||
parent.symbolic_globals,
|
||||
parent.symbolic_torch_function_state,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
func,
|
||||
)
|
||||
return tracer
|
||||
|
@ -153,7 +153,9 @@ class CPythonTestCase(TestCase):
|
||||
assertTupleEqual = unittest.TestCase.assertTupleEqual
|
||||
assertSetEqual = unittest.TestCase.assertSetEqual
|
||||
assertDictEqual = polyfills.assert_dict_equal
|
||||
# pyrefly: ignore # bad-override
|
||||
assertRaises = unittest.TestCase.assertRaises
|
||||
# pyrefly: ignore # bad-override
|
||||
assertRaisesRegex = unittest.TestCase.assertRaisesRegex
|
||||
assertWarns = unittest.TestCase.assertWarns
|
||||
assertWarnsRegex = unittest.TestCase.assertWarnsRegex
|
||||
@ -169,8 +171,10 @@ class CPythonTestCase(TestCase):
|
||||
) -> Callable[..., Any]:
|
||||
# We want to compile only the test function, excluding any setup code
|
||||
# from unittest
|
||||
|
||||
method = getattr(self, self._testMethodName)
|
||||
method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
|
||||
|
||||
setattr(self, self._testMethodName, method)
|
||||
return fn
|
||||
|
||||
|
@ -207,6 +207,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
|
||||
launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py"))
|
||||
with open(launch_file) as f:
|
||||
launch_code = f.read()
|
||||
|
||||
self.assertTrue(os.path.exists(launch_file))
|
||||
|
||||
args = ["python3", launch_file, "minify", *minifier_args]
|
||||
@ -218,6 +219,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
|
||||
print("minifier stdout:", launch_proc.stdout.decode("utf-8"))
|
||||
stderr = launch_proc.stderr.decode("utf-8")
|
||||
print("minifier stderr:", stderr)
|
||||
|
||||
self.assertNotIn("Input graph did not fail the tester", stderr)
|
||||
|
||||
return launch_proc, launch_code
|
||||
@ -230,6 +232,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
|
||||
repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py"))
|
||||
with open(repro_file) as f:
|
||||
repro_code = f.read()
|
||||
|
||||
self.assertTrue(os.path.exists(repro_file))
|
||||
|
||||
repro_proc = self._maybe_subprocess_run(
|
||||
@ -296,11 +299,14 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
|
||||
if expected_error is None:
|
||||
# Just check that there was no error
|
||||
self.assertEqual(test_proc.returncode, 0)
|
||||
|
||||
self.assertIsNone(repro_dir)
|
||||
return None
|
||||
# NB: Intentionally do not test return code; we only care about
|
||||
# actually generating the repro, we don't have to crash
|
||||
|
||||
self.assertIn(expected_error, test_proc.stderr.decode("utf-8"))
|
||||
|
||||
self.assertIsNotNone(repro_dir)
|
||||
print("running minifier", file=sys.stderr)
|
||||
_minifier_proc, minifier_code = self._run_minifier_launcher(
|
||||
@ -311,6 +317,7 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
|
||||
)
|
||||
print("running repro", file=sys.stderr)
|
||||
repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate)
|
||||
|
||||
self.assertIn(expected_error, repro_proc.stderr.decode("utf-8"))
|
||||
self.assertNotEqual(repro_proc.returncode, 0)
|
||||
return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code)
|
||||
|
@ -496,6 +496,7 @@ def make_test_cls_with_patches(
|
||||
def skipIfNotPy311(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
if sys.version_info >= (3, 11):
|
||||
return fn
|
||||
# pyrefly: ignore # bad-return, bad-argument-type
|
||||
return unittest.skip(fn)
|
||||
|
||||
|
||||
|
@ -3005,6 +3005,7 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
|
||||
obj = torch_dir + k[len("torch/") :]
|
||||
if obj is not None:
|
||||
if is_annotate_wrapped_function(obj):
|
||||
# pyrefly: ignore # missing-attribute
|
||||
obj = obj.__wrapped__
|
||||
if is_lru_cache_wrapped_function(obj):
|
||||
obj = obj.__wrapped__
|
||||
|
@ -295,11 +295,13 @@ def increment_op_count(cnt: int) -> None:
|
||||
def calculate_time_spent() -> dict[str, float]:
|
||||
total_by_key = {}
|
||||
for phase, timing in cumulative_time_spent_ns.items():
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
total_by_key[phase] = timing / 1e9
|
||||
|
||||
total_by_key["total_wall_time"] = total_by_key.get(
|
||||
"entire_frame_compile", 0
|
||||
) + total_by_key.get("entire_backward_compile", 0)
|
||||
# pyrefly: ignore # bad-return
|
||||
return total_by_key
|
||||
|
||||
|
||||
@ -798,6 +800,7 @@ def compile_times(repr: Literal["str"], aggregate: bool = False) -> str: ...
|
||||
|
||||
|
||||
@overload
|
||||
# pyrefly: ignore # inconsistent-overload
|
||||
def compile_times(
|
||||
repr: Literal["csv"], aggregate: bool = False
|
||||
) -> tuple[list[str], list[object]]: ...

@ -1463,6 +1466,7 @@ class CompilationMetrics:
compile_id = all_metrics.get("compile_id")
all_metrics["compile_id"] = str(compile_id) if compile_id else None

# pyrefly: ignore # bad-argument-type
return cls(**all_metrics)

@ -2253,6 +2257,7 @@ def is_jit_model(
Union[
torch.jit._trace.TopLevelTracedModule,
torch.jit._script.RecursiveScriptModule,
# pyrefly: ignore # invalid-param-spec
torch.jit.ScriptFunction[Any, Any],
torch.jit.ScriptModule,
]

@ -2361,6 +2366,7 @@ def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]:
cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
saved_state = [
(param, param._version, torch.clone(param))
# pyrefly: ignore # bad-argument-type
for param in itertools.chain(gm.parameters(), gm.buffers())
]

@ -2626,13 +2632,16 @@ def get_items_from_dict(obj: dict[K, V]) -> Iterable[tuple[K, Union[V, Any]]]:
if istype(obj, (dict, OrderedDict)):
return obj.items()
elif isinstance(obj, OrderedDict):
# pyrefly: ignore # bad-argument-type
return [(k, OrderedDict.__getitem__(obj, k)) for k in OrderedDict.keys(obj)]
else:
# pyrefly: ignore # bad-argument-type
return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)]


def nn_module_new(cls: Any) -> Any:
obj = object_new(cls)
# pyrefly: ignore # bad-argument-type
torch.nn.Module.__init__(obj)
return obj

@ -2679,6 +2688,7 @@ def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any:
dict_class = dict
if isinstance(d, OrderedDict):
dict_class = OrderedDict
# pyrefly: ignore # bad-argument-type
return next(itertools.islice(dict_class.keys(d), n, n + 1))

@ -3222,8 +3232,10 @@ def format_func_info(code: CodeType) -> str:
@contextlib.contextmanager
def disable_cache_limit() -> Generator[None, None, None]:
prior = config.recompile_limit
# pyrefly: ignore # bad-assignment
config.recompile_limit = sys.maxsize
prior_acc_limit = config.accumulated_recompile_limit
# pyrefly: ignore # bad-assignment
config.accumulated_recompile_limit = sys.maxsize

try:

@ -3958,6 +3970,7 @@ class numpy_operator_wrapper(Generic[_P, R]):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any:
assert not kwargs

# pyrefly: ignore # bad-assignment
args = (
tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
)

@ -4157,6 +4170,7 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
# (x) + (y)
# ~~^~~~~~~
while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
# pyrefly: ignore # unbound-name
if ch in "\\#":
cur_lineno, cur_col = nextline(cur_lineno, cur_col)
else:

@ -4507,6 +4521,7 @@ class GmWrapper(torch.nn.Module):
self.unflatten_fn = unflatten_fn

def forward(self, *args: Any) -> Any:
# pyrefly: ignore # annotation-mismatch
args: list[Any] = list(args)
return self.gm(*self.unflatten_fn(args))

@ -1028,6 +1028,7 @@ class BuiltinVariable(VariableTracker):

def call_self_handler(tx: "InstructionTranslator", args, kwargs):
try:
# pyrefly: ignore # not-callable
result = self_handler(tx, *args, **kwargs)
if result is not None:
return result

@ -1035,6 +1036,7 @@ class BuiltinVariable(VariableTracker):
# Check if binding is bad. inspect signature bind is expensive.
# So check only when handler call fails.
try:
# pyrefly: ignore # bad-argument-type
inspect.signature(self_handler).bind(tx, *args, **kwargs)
except TypeError as e:
has_constant_handler = obj.has_constant_handler(args, kwargs)

@ -1087,6 +1089,7 @@ class BuiltinVariable(VariableTracker):
hints=[*graph_break_hints.DYNAMO_BUG],
from_exc=exc,
)
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, res)

else:

@ -1115,6 +1118,7 @@ class BuiltinVariable(VariableTracker):
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, res)

handlers.append(constant_fold_handler)

@ -1437,6 +1441,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name)
if resolved_fn in dict_methods:
if isinstance(args[0], variables.UserDefinedDictVariable):
# pyrefly: ignore # missing-attribute
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.ConstDictVariable):
return args[0].call_method(tx, name, args[1:], kwargs)

@ -1445,6 +1450,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name)
if resolved_fn in set_methods:
if isinstance(args[0], variables.UserDefinedSetVariable):
# pyrefly: ignore # missing-attribute
return args[0]._set_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.SetVariable):
return args[0].call_method(tx, name, args[1:], kwargs)

@ -1533,10 +1539,12 @@ class BuiltinVariable(VariableTracker):
if type(arg.value).__str__ is object.__str__:
# Rely on the object str method
try:
# pyrefly: ignore # unbound-name
return variables.ConstantVariable.create(value=str_method())
except AttributeError:
# Graph break
return
# pyrefly: ignore # unbound-name
elif is_wrapper_or_member_descriptor(str_method):
unimplemented_v2(
gb_type="Attempted to a str() method implemented in C/C++",

@ -1653,8 +1661,10 @@ class BuiltinVariable(VariableTracker):
else:
raw_b = b.raw_value
if self.fn is max:
# pyrefly: ignore # missing-attribute
raw_res = max(a.raw_value, raw_b)
else:
# pyrefly: ignore # missing-attribute
raw_res = min(a.raw_value, raw_b)

need_unwrap = any(

@ -2106,6 +2116,7 @@ class BuiltinVariable(VariableTracker):
)

if isinstance(arg, variables.UserDefinedExceptionClassVariable):
# pyrefly: ignore # unbound-name
return ConstantVariable.create(isinstance(arg_type, isinstance_type))

isinstance_type_tuple: tuple[type, ...]

@ -2138,8 +2149,10 @@ class BuiltinVariable(VariableTracker):
# through it. This is a limitation of the current implementation.
# Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it
# might not be a big issue and we trade off it for performance.
# pyrefly: ignore # unbound-name
val = issubclass(arg_type, isinstance_type_tuple)
except TypeError:
# pyrefly: ignore # unbound-name
val = arg_type in isinstance_type_tuple
return variables.ConstantVariable.create(val)

@ -2161,6 +2174,7 @@ class BuiltinVariable(VariableTracker):

# WARNING: This might run arbitrary user code `__subclasscheck__`.
# See the comment in call_isinstance above.
# pyrefly: ignore # unbound-name
return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))

def call_super(self, tx: "InstructionTranslator", a, b):

@ -2206,7 +2220,9 @@ class BuiltinVariable(VariableTracker):
value = getattr(self.fn, name)
except AttributeError:
raise_observed_exception(AttributeError, tx)
# pyrefly: ignore # unbound-name
if not callable(value):
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, value, source)
return variables.GetAttrVariable(self, name, source=source)

@ -651,6 +651,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
def handle_use_deterministic_algorithms(
self, tx: "InstructionTranslator", mode, warn_only=False
):
# pyrefly: ignore # missing-attribute
if warn_only and warn_only.as_python_constant():
unimplemented_v2(
gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)",

@ -1035,6 +1036,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
else:
raise torch._dynamo.exc.Unsupported("branch not supported")
return variables.ConstantVariable.create(
# pyrefly: ignore # bad-argument-type
torch.fx.experimental.symbolic_shapes.guard_scalar(val)
)

@ -1081,6 +1083,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
return

return variables.ConstantVariable.create(
# pyrefly: ignore # bad-argument-type
torch.fx.experimental.symbolic_shapes.has_static_value(val)
)

@ -1212,13 +1215,17 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
)

# need to guard only on no-arg get_device_module
# pyrefly: ignore # unbound-name
if device is None:
source = CallFunctionNoArgsSource(self.source)
install_guard(source.make_guard(GuardBuilder.ID_MATCH))
# assumes `module` is in the form `torch.xyz`
new_source = AttrSource(
TorchSource(), module.__name__.rsplit(".", maxsplit=1)[-1]
TorchSource(),
# pyrefly: ignore # unbound-name
module.__name__.rsplit(".", maxsplit=1)[-1],
)
# pyrefly: ignore # unbound-name
return VariableTracker.build(tx, module, new_source)

@register(torch.set_default_device)

@ -1373,9 +1380,12 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
f"{fn.__name__}_spec", f_spec
)
input_spec_proxy = tx.output.register_static_attr_and_return_proxy(
fn.__name__ + "_input_spec", input_spec
fn.__name__ + "_input_spec",
# pyrefly: ignore # unbound-name
input_spec,
)
f_spec_proxy.node.type = type(f_spec)
# pyrefly: ignore # unbound-name
input_spec_proxy.node.type = type(input_spec)
all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args)

@ -1716,6 +1726,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
)

# this results in cleaner graphs, but only works for inputs
# pyrefly: ignore # missing-attribute
if data.source:
return cls._nn_param_via_prefix_insert(tx, data, requires_grad)

@ -1734,7 +1745,9 @@ For now, dynamo will explicitly graph break when it encounters user code with th

# TODO[@lucaskabela]: Remove the behavior below since it is deprecated
if isinstance(
data, TensorWithTFOverrideVariable
data,
TensorWithTFOverrideVariable,
# pyrefly: ignore # missing-attribute
) or is_traceable_wrapper_subclass_type(data.class_type):
unimplemented_v2(
gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass",

@ -1757,8 +1770,11 @@ For now, dynamo will explicitly graph break when it encounters user code with th
)

try:
# pyrefly: ignore # missing-attribute
shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
# pyrefly: ignore # missing-attribute
dtype = data.var_getattr(tx, "dtype").as_python_constant()
# pyrefly: ignore # missing-attribute
device = data.var_getattr(tx, "device").as_python_constant()
except NotImplementedError as e:
unimplemented_v2(

@ -1773,9 +1789,13 @@ For now, dynamo will explicitly graph break when it encounters user code with th
)

placeholder = tx.output.synthetic_graph_input(
new_parameter_placeholder, [shape, dtype, device, requires_grad]
new_parameter_placeholder,
# pyrefly: ignore # unbound-name
[shape, dtype, device, requires_grad],
)
# pyrefly: ignore # missing-attribute
if data.requires_grad:
# pyrefly: ignore # missing-attribute
data = data.call_method(tx, "detach", [], {})

from .builder import wrap_fx_proxy

@ -1785,6 +1805,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
tx.output.create_proxy(
"call_function",
tracable_create_parameter,
# pyrefly: ignore # missing-attribute
(data.as_proxy(), placeholder.as_proxy()),
{},
),

@ -646,7 +646,6 @@ def update_schema():
assert thrift_content[1].startswith("// checksum<<")
thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))

# pyrefly: ignore # import-error
from yaml import load, Loader

dst = load(content, Loader=Loader)

@ -34,7 +34,7 @@ from torch.fx._pytree import (
_deregister_pytree_flatten_spec,
register_pytree_flatten_spec,
)
from torch.utils._pytree import ( # pyrefly: ignore # deprecated
from torch.utils._pytree import (
_deregister_pytree_node,
_register_pytree_node,
Context,

@ -198,7 +198,7 @@ def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str:
to a file. The yaml string can be loaded back into an operator profile
structure using `read_profiles_from_yaml`.
"""
# pyrefly: ignore # import-error

import yaml

from torch._export.serde.serialize import (

@ -263,7 +263,7 @@ def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]:
"""
Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
"""
# pyrefly: ignore # import-error

import yaml

from torch._export.serde.serialize import (

@ -1,5 +1,5 @@
from .core import dispatch
from .dispatcher import ( # pyrefly: ignore # deprecated
from .dispatcher import (
Dispatcher,
halt_ordering,
MDNotImplementedError,

@ -153,6 +153,7 @@ class CausalBias(torch.Tensor):
diagonal=diagonal_offset,
)

# pyrefly: ignore # bad-return
def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
"""
Materializes the causal bias into a tensor form.

@ -84,6 +84,7 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]


# pyrefly: ignore # invalid-inheritance
class FlexKernelOptions(TypedDict, total=False):
"""Options for controlling the behavior of FlexAttention kernels.

@ -127,76 +128,93 @@ class FlexKernelOptions(TypedDict, total=False):
"""

# Performance tuning options
# pyrefly: ignore # invalid-annotation
num_warps: NotRequired[int]
"""Number of warps to use in the CUDA kernel. Higher values may improve performance
but increase register pressure. Default is determined by autotuning."""

# pyrefly: ignore # invalid-annotation
num_stages: NotRequired[int]
"""Number of pipeline stages in the CUDA kernel. Higher values may improve performance
but increase shared memory usage. Default is determined by autotuning."""

# pyrefly: ignore # invalid-annotation
BLOCK_M: NotRequired[int]
"""Thread block size for the sequence length dimension of Q in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""

# pyrefly: ignore # invalid-annotation
BLOCK_N: NotRequired[int]
"""Thread block size for the sequence length dimension of K/V in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""

# Backward-specific block sizes (when prefixed with 'bwd_')
# pyrefly: ignore # invalid-annotation
BLOCK_M1: NotRequired[int]
"""Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'.
Default is determined by autotuning."""

# pyrefly: ignore # invalid-annotation
BLOCK_N1: NotRequired[int]
"""Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'.
Default is determined by autotuning."""

# pyrefly: ignore # invalid-annotation
BLOCK_M2: NotRequired[int]
"""Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'.
Default is determined by autotuning."""

# pyrefly: ignore # invalid-annotation
BLOCK_N2: NotRequired[int]
"""Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'.
Default is determined by autotuning."""

# pyrefly: ignore # invalid-annotation
PRESCALE_QK: NotRequired[bool]
"""Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but
may have more numerical error. Default: False."""

# pyrefly: ignore # invalid-annotation
ROWS_GUARANTEED_SAFE: NotRequired[bool]
"""If True, guarantees that at least one value in each row is not masked out.
Allows skipping safety checks for better performance. Only set this if you are certain
your mask guarantees this property. For example, causal attention is guaranteed safe
because each query has at least 1 key-value to attend to. Default: False."""

# pyrefly: ignore # invalid-annotation
BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]
"""If True, guarantees that all blocks in the mask are contiguous.
Allows optimizing block traversal. For example, causal masks would satisfy this,
but prefix_lm + sliding window would not. Default: False."""

# pyrefly: ignore # invalid-annotation
WRITE_DQ: NotRequired[bool]
"""Controls whether gradient scatters are done in the DQ iteration loop of the backward pass.
Setting this to False will force this to happen in the DK loop which depending on your
specific score_mod and mask_mod might be faster. Default: True."""

# pyrefly: ignore # invalid-annotation
FORCE_USE_FLEX_ATTENTION: NotRequired[bool]
"""If True, forces the use of the flex attention kernel instead of potentially using
the more optimized flex-decoding kernel for short sequences. This can be a helpful
option for debugging. Default: False."""

# pyrefly: ignore # invalid-annotation
USE_TMA: NotRequired[bool]
"""Whether to use Tensor Memory Accelerator (TMA) on supported hardware.
This is experimental and may not work on all hardware, currently specific
to NVIDIA GPUs Hopper+. Default: False."""

# ROCm-specific options
# pyrefly: ignore # invalid-annotation
kpack: NotRequired[int]
"""ROCm-specific kernel packing parameter."""

# pyrefly: ignore # invalid-annotation
matrix_instr_nonkdim: NotRequired[int]
"""ROCm-specific matrix instruction non-K dimension."""

# pyrefly: ignore # invalid-annotation
waves_per_eu: NotRequired[int]
"""ROCm-specific waves per execution unit."""

@ -581,6 +599,7 @@ class BlockMask:
block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
seq_lengths = (self.seq_lengths,) # type: ignore[assignment]

# pyrefly: ignore # not-iterable
return (
*seq_lengths,
self.kv_num_blocks,

@ -753,6 +772,7 @@ class BlockMask:
partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
if self.full_kv_num_blocks is not None:
assert self.full_kv_indices is not None
# pyrefly: ignore # bad-return
return partial_dense | _ordered_to_dense(
self.full_kv_num_blocks, self.full_kv_indices
)

@ -78,6 +78,7 @@ class ModuleWrapper(nn.Module):

# nn.Module defines training as a boolean
@property # type: ignore[override]
# pyrefly: ignore # bad-override
def training(self):
return self.cpp_module.training

@ -1266,6 +1266,7 @@ def adaptive_max_pool2d_with_indices(
output_size,
return_indices=return_indices,
)
# pyrefly: ignore # bad-argument-type
output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool2d(input, output_size)

@ -1323,6 +1324,7 @@ def adaptive_max_pool3d_with_indices(
output_size,
return_indices=return_indices,
)
# pyrefly: ignore # bad-argument-type
output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool3d(input, output_size)

@ -1381,6 +1383,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> T
"""
if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
# pyrefly: ignore # bad-argument-type
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)

@ -1396,6 +1399,7 @@ def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> T
"""
if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
# pyrefly: ignore # bad-argument-type
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool3d(input, _output_size)

@ -2431,6 +2435,7 @@ def _no_grad_embedding_renorm_(
input: Tensor,
max_norm: float,
norm_type: float,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, Tensor]:
torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)

@ -2684,6 +2689,7 @@ def embedding_bag(

if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested:
include_last_offset = True
# pyrefly: ignore # missing-attribute
offsets = input.offsets()
input = input.values().reshape(-1)
if per_sample_weights is not None:

@ -2818,6 +2824,7 @@ def batch_norm(
eps=eps,
)
if training:
# pyrefly: ignore # bad-argument-type
_verify_batch_size(input.size())

return torch.batch_norm(

@ -2873,6 +2880,7 @@ def instance_norm(
eps=eps,
)
if use_input_stats:
# pyrefly: ignore # bad-argument-type
_verify_spatial_size(input.size())
return torch.instance_norm(
input,

@ -2998,11 +3006,13 @@ def local_response_norm(
div = input.mul(input)
if dim == 3:
div = div.unsqueeze(1)
# pyrefly: ignore # bad-argument-type
div = pad(div, (0, 0, size // 2, (size - 1) // 2))
div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
else:
sizes = input.size()
div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
# pyrefly: ignore # bad-argument-type
div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
div = div.view(sizes)

@ -3151,7 +3161,12 @@ def nll_loss(
if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce)
return torch._C._nn.nll_loss_nd(
input, target, weight, _Reduction.get_enum(reduction), ignore_index
input,
target,
weight,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
ignore_index,
)


@ -3296,6 +3311,7 @@ def gaussian_nll_loss(
var.clamp_(min=eps)

# Calculate the loss
# pyrefly: ignore # unsupported-operation
loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
if full:
loss += 0.5 * math.log(2 * math.pi)

@ -3471,6 +3487,7 @@ def cross_entropy(
input,
target,
weight,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
ignore_index,
label_smoothing,

@ -3535,6 +3552,7 @@ def binary_cross_entropy(
new_size = _infer_size(target.size(), weight.size())
weight = weight.expand(new_size)

# pyrefly: ignore # bad-argument-type
return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)


@ -3663,11 +3681,18 @@ def smooth_l1_loss(

if beta == 0.0:
return torch._C._nn.l1_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction)
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
)
else:
return torch._C._nn.smooth_l1_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction), beta
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
beta,
)

@ -3725,7 +3750,11 @@ def huber_loss(
if weight is None:
# Use the optimized C++ backend for standard Huber loss
return torch._C._nn.huber_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction), delta
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
delta,
)
else:
if weight.size() != input.size():

@ -3733,7 +3762,11 @@ def huber_loss(

# Calculate the unweighted loss first
unweighted_loss = torch._C._nn.huber_loss(
expanded_input, expanded_target, _Reduction.get_enum("none"), delta
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum("none"),
delta,
)

# Apply weight to the unweighted loss

@ -3820,7 +3853,10 @@ def l1_loss(
)
else:
return torch._C._nn.l1_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction)
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
)


@ -3895,7 +3931,10 @@ def mse_loss(
)
else:
return torch._C._nn.mse_loss(
expanded_input, expanded_target, _Reduction.get_enum(reduction)
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
_Reduction.get_enum(reduction),
)


@ -4032,6 +4071,7 @@ def multilabel_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
reduction_enum = _Reduction.get_enum(reduction)
# pyrefly: ignore # bad-argument-type
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)


@ -4073,6 +4113,7 @@ def soft_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
reduction_enum = _Reduction.get_enum(reduction)
# pyrefly: ignore # bad-argument-type
return torch._C._nn.soft_margin_loss(input, target, reduction_enum)


@ -4237,7 +4278,13 @@ def multi_margin_loss(
raise ValueError("weight must be one-dimensional")

return torch._C._nn.multi_margin_loss(
input, target, p, margin, weight, reduction_enum
input,
target,
p,
margin,
weight,
# pyrefly: ignore # bad-argument-type
reduction_enum,
)


@ -4383,6 +4430,7 @@ def upsample(  # noqa: F811
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
# pyrefly: ignore # bad-return
) -> Tensor:  # noqa: B950
pass

@ -4394,6 +4442,7 @@ def upsample(  # noqa: F811
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
# pyrefly: ignore # bad-return
) -> Tensor:  # noqa: B950
pass

@ -4496,6 +4545,7 @@ def interpolate(  # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor:  # noqa: B950
pass

@ -4509,6 +4559,7 @@ def interpolate(  # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor:  # noqa: B950
pass

@ -4522,6 +4573,7 @@ def interpolate(  # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor:  # noqa: B950
pass

@ -4535,6 +4587,7 @@ def interpolate(  # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
) -> Tensor:
pass

@ -4709,6 +4762,7 @@ def interpolate(  # noqa: F811
(
torch.floor(
(
# pyrefly: ignore # missing-attribute
input.size(i + 2).float()
* torch.tensor(scale_factors[i], dtype=torch.float32)
).float()

@ -4733,21 +4787,28 @@ def interpolate(  # noqa: F811
)

if input.dim() == 3 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)

if input.dim() == 3 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)

if input.dim() == 3 and mode == "area":
assert output_size is not None
# pyrefly: ignore # bad-argument-type
return adaptive_avg_pool1d(input, output_size)
if input.dim() == 4 and mode == "area":
assert output_size is not None

@ -4759,13 +4820,21 @@ def interpolate(  # noqa: F811
if input.dim() == 3 and mode == "linear":
assert align_corners is not None
return torch._C._nn.upsample_linear1d(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
if input.dim() == 4 and mode == "bilinear":
assert align_corners is not None
if antialias:
return torch._C._nn._upsample_bilinear2d_aa(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
# Two levels are necessary to prevent TorchScript from touching
# are_deterministic_algorithms_enabled.

@ -4778,7 +4847,11 @@ def interpolate(  # noqa: F811
"torch._decomp.decompositions"
)._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_bilinear2d(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
if input.dim() == 5 and mode == "trilinear":
assert align_corners is not None

@ -4793,16 +4866,28 @@ def interpolate(  # noqa: F811
"torch._decomp.decompositions"
)._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_trilinear3d(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
if input.dim() == 4 and mode == "bicubic":
assert align_corners is not None
if antialias:
return torch._C._nn._upsample_bicubic2d_aa(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)
return torch._C._nn.upsample_bicubic2d(
input, output_size, align_corners, scale_factors
input,
# pyrefly: ignore # bad-argument-type
output_size,
align_corners,
scale_factors,
)

if input.dim() == 3 and mode == "bilinear":

@ -4834,6 +4919,7 @@ def upsample_nearest(  # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass

@ -4843,6 +4929,7 @@ def upsample_nearest(  # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass

@ -4884,6 +4971,7 @@ def upsample_bilinear(  # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass

@ -4893,6 +4981,7 @@ def upsample_bilinear(  # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass

@ -4902,6 +4991,7 @@ def upsample_bilinear(  # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[list[float]] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass

@ -4911,6 +5001,7 @@ def upsample_bilinear(  # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[list[float]] = None,
# pyrefly: ignore # bad-return
) -> Tensor:
pass

@ -5717,6 +5808,7 @@ def _in_projection_packed(
.squeeze(-2)
.contiguous()
)
# pyrefly: ignore # bad-return
return proj[0], proj[1], proj[2]
else:
# encoder-decoder attention

@ -5735,6 +5827,7 @@ def _in_projection_packed(
.squeeze(-2)
.contiguous()
)
# pyrefly: ignore # bad-return
return (q_proj, kv_proj[0], kv_proj[1])
else:
w_q, w_k, w_v = w.chunk(3)

@ -5742,6 +5835,7 @@ def _in_projection_packed(
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
# pyrefly: ignore # bad-return
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


@ -6372,8 +6466,10 @@ def multi_head_attention_forward(
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
# pyrefly: ignore # bad-argument-type
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
# pyrefly: ignore # bad-argument-type
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert bias_k is None

@ -6382,8 +6478,10 @@ def multi_head_attention_forward(
#
# reshape q, k, v for multihead attention and make them batch first
#
# pyrefly: ignore # no-matching-overload
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None:
# pyrefly: ignore # no-matching-overload
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed

@ -6395,6 +6493,7 @@ def multi_head_attention_forward(
)
k = static_k
if static_v is None:
# pyrefly: ignore # no-matching-overload
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed

@ -6410,14 +6509,20 @@ def multi_head_attention_forward(
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
# pyrefly: ignore # no-matching-overload
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
dim=1,
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
# pyrefly: ignore # no-matching-overload
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
dim=1,
)
if attn_mask is not None:
# pyrefly: ignore # bad-argument-type
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
# pyrefly: ignore # bad-argument-type
key_padding_mask = pad(key_padding_mask, (0, 1))

# update source sequence length after adjustments

@ -6467,6 +6572,7 @@ def multi_head_attention_forward(
attn_output = torch.bmm(attn_output_weights, v)

attn_output = (
# pyrefly: ignore # no-matching-overload
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

@ -6493,13 +6599,16 @@ def multi_head_attention_forward(
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

q = q.view(bsz, num_heads, tgt_len, head_dim)
# pyrefly: ignore # no-matching-overload
k = k.view(bsz, num_heads, src_len, head_dim)
# pyrefly: ignore # no-matching-overload
v = v.view(bsz, num_heads, src_len, head_dim)

attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
# pyrefly: ignore # no-matching-overload
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)

@ -500,6 +500,7 @@ def xavier_normal_(


def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int:
# pyrefly: ignore # bad-assignment
mode = mode.lower()
valid_modes = ["fan_in", "fan_out"]
if mode not in valid_modes:

@ -6,6 +6,7 @@ from torch.autograd.function import Function

class SyncBatchNorm(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
self,
input,

@ -210,6 +211,7 @@ class SyncBatchNorm(Function):

class CrossMapLRN2d(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
ctx.size = size
ctx.alpha = alpha

@ -265,6 +267,7 @@ class CrossMapLRN2d(Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, output = ctx.saved_tensors
grad_input = grad_output.new()

@ -306,6 +309,7 @@ class CrossMapLRN2d(Function):

class BackwardHookFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, *args):
ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
return args

@ -72,6 +72,7 @@ class _NormBase(Module):
torch.tensor(
0,
dtype=torch.long,
# pyrefly: ignore # bad-argument-type
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
),
)

@ -221,6 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
# affine and track_running_stats are hardcoded to False to
# avoid creating tensors that will soon be overwritten.

@ -234,22 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
# pyrefly: ignore # bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)
if self.track_running_stats:
# pyrefly: ignore # bad-argument-type
self.running_mean = UninitializedBuffer(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.running_var = UninitializedBuffer(**factory_kwargs)
self.num_batches_tracked = torch.tensor(
0,
dtype=torch.long,
# pyrefly: ignore # bad-argument-type
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
)

def reset_parameters(self) -> None:
# pyrefly: ignore # bad-argument-type
if not self.has_uninitialized_params() and self.num_features != 0:
super().reset_parameters()

def initialize_parameters(self, input) -> None:  # type: ignore[override]
# pyrefly: ignore # bad-argument-type
if self.has_uninitialized_params():
self.num_features = input.shape[1]
if self.affine:

@ -109,6 +109,7 @@ class Sequential(Module):
def __init__(self, *args: Module) -> None: ...

@overload
# pyrefly: ignore # inconsistent-overload
def __init__(self, arg: OrderedDict[str, Module]) -> None: ...

def __init__(self, *args):

@ -472,6 +473,7 @@ class ModuleList(Module):
return self

def pop(self, key: Union[int, slice]) -> Module:
# pyrefly: ignore # index-error
v = self[key]
del self[key]
return v

@ -623,9 +625,11 @@ class ModuleDict(Module):
"ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(m).__name__
)
# pyrefly: ignore # bad-argument-type
if not len(m) == 2:
raise ValueError(
"ModuleDict update sequence element "
# pyrefly: ignore # bad-argument-type
"#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
)
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]

@ -684,6 +688,7 @@ class ParameterList(Module):
def __getitem__(self, idx: int) -> Any: ...

@overload
# pyrefly: ignore # inconsistent-overload
def __getitem__(self: T, idx: slice) -> T: ...

def __getitem__(self, idx):

@ -769,9 +774,11 @@ class ParameterList(Module):
size_str,
device_str,
)
# pyrefly: ignore # bad-argument-type
child_lines.append("  (" + str(k) + "): " + parastr)
else:
child_lines.append(
# pyrefly: ignore # bad-argument-type
"  (" + str(k) + "): Object of type: " + type(p).__name__
)

@ -979,9 +986,11 @@ class ParameterDict(Module):
"ParameterDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(p).__name__
)
# pyrefly: ignore # bad-argument-type
if not len(p) == 2:
raise ValueError(
"ParameterDict update sequence element "
# pyrefly: ignore # bad-argument-type
"#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
)
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment

@ -1002,9 +1011,11 @@ class ParameterDict(Module):
size_str,
device_str,
)
# pyrefly: ignore # bad-argument-type
child_lines.append("  (" + str(k) + "): " + parastr)
else:
child_lines.append(
# pyrefly: ignore # bad-argument-type
"  (" + str(k) + "): Object of type: " + type(p).__name__
)
tmpstr = "\n".join(child_lines)

@ -363,6 +363,7 @@ class Conv1d(_ConvNd):
self.dilation,
self.groups,
)
# pyrefly: ignore # no-matching-overload
return F.conv1d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
)

@ -540,6 +541,7 @@ class Conv2d(_ConvNd):
self.dilation,
self.groups,
)
# pyrefly: ignore # no-matching-overload
return F.conv2d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
)

@ -709,6 +711,7 @@ class Conv3d(_ConvNd):
self.dilation,
self.groups,
)
# pyrefly: ignore # no-matching-overload
return F.conv3d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
)

@ -1494,6 +1497,7 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d):  # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,

@ -1508,9 +1512,11 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d):  # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@ -1563,6 +1569,7 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d):  # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,

@ -1577,9 +1584,11 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d):  # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@ -1633,6 +1642,7 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d):  # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,

@ -1647,9 +1657,11 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d):  # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@ -1701,6 +1713,7 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d):  # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,

@ -1716,9 +1729,11 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d):  # type: ignore[mi
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@ -1770,6 +1785,7 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d):  # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,

@ -1785,9 +1801,11 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d):  # type: ignore[mi
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@ -1839,6 +1857,7 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d):  # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
super().__init__(
0,
0,

@ -1854,9 +1873,11 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d):  # type: ignore[mi
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@ -172,7 +172,9 @@ class LazyModuleMixin:
def __init__(self: _LazyProtocol, *args, **kwargs):
# Mypy doesn't like this super call in a mixin
super().__init__(*args, **kwargs)  # type: ignore[misc]
# pyrefly: ignore # read-only
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
# pyrefly: ignore # read-only
self._initialize_hook = self.register_forward_pre_hook(
self._infer_parameters, with_kwargs=True
)

@ -286,6 +286,7 @@ class LazyLinear(LazyModuleMixin, Linear):
"""

cls_to_become = Linear  # type: ignore[assignment]
# pyrefly: ignore # bad-override
weight: UninitializedParameter
bias: UninitializedParameter  # type: ignore[assignment]

@ -295,16 +296,20 @@ class LazyLinear(LazyModuleMixin, Linear):
factory_kwargs = {"device": device, "dtype": dtype}
# bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten.
# pyrefly: ignore # bad-argument-type
super().__init__(0, 0, False)
# pyrefly: ignore # bad-argument-type
self.weight = UninitializedParameter(**factory_kwargs)
self.out_features = out_features
if bias:
# pyrefly: ignore # bad-argument-type
self.bias = UninitializedParameter(**factory_kwargs)

def reset_parameters(self) -> None:
"""
Resets parameters based on their initialization used in ``__init__``.
"""
# pyrefly: ignore # bad-argument-type
if not self.has_uninitialized_params() and self.in_features != 0:
super().reset_parameters()

@ -312,6 +317,7 @@ class LazyLinear(LazyModuleMixin, Linear):
"""
Infers ``in_features`` based on ``input`` and initializes parameters.
"""
# pyrefly: ignore # bad-argument-type
if self.has_uninitialized_params():
with torch.no_grad():
self.in_features = input.shape[-1]

@ -38,11 +38,13 @@ T = TypeVar("T", bound="Module")


class _IncompatibleKeys(
# pyrefly: ignore # invalid-inheritance
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
):
__slots__ = ()

def __repr__(self) -> str:
# pyrefly: ignore # missing-attribute
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"
return super().__repr__()

@ -91,6 +93,7 @@ class _WrappedHook:
def __getstate__(self) -> dict:
result = {"hook": self.hook, "with_module": self.with_module}
if self.with_module:
# pyrefly: ignore # unsupported-operation
result["module"] = self.module()

return result

@ -976,7 +979,9 @@ class Module:
# Decrement use count of the gradient by setting to None
param.grad = None
param_applied = torch.nn.Parameter(
param_applied, requires_grad=param.requires_grad
# pyrefly: ignore # bad-argument-type
param_applied,
requires_grad=param.requires_grad,
)
torch.utils.swap_tensors(param, param_applied)
except Exception as e:

@ -987,11 +992,13 @@ class Module:
) from e
out_param = param
elif p_should_use_set_data:
# pyrefly: ignore # bad-assignment
param.data = param_applied
out_param = param
else:
assert isinstance(param, Parameter)
assert param.is_leaf
# pyrefly: ignore # bad-argument-type
out_param = Parameter(param_applied, param.requires_grad)
self._parameters[key] = out_param

@ -2253,6 +2260,7 @@ class Module:

if destination is None:
destination = OrderedDict()
# pyrefly: ignore # missing-attribute
destination._metadata = OrderedDict()

local_metadata = dict(version=self._version)

@ -2402,7 +2410,9 @@ class Module:
if k not in self._non_persistent_buffers_set
}
local_name_params = itertools.chain(
self._parameters.items(), persistent_buffers.items()
self._parameters.items(),
# pyrefly: ignore # bad-argument-type
persistent_buffers.items(),
)
local_state = {k: v for k, v in local_name_params if v is not None}
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)

@ -84,6 +84,7 @@ class CircularPad1d(_CircularPadNd):
[5., 6., 7., 4., 5., 6., 7., 4.]]])
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int]

def __init__(self, padding: _size_2_t) -> None:

@ -144,6 +145,7 @@ class CircularPad2d(_CircularPadNd):
[8., 6., 7., 8., 6.]]]])
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int]

def __init__(self, padding: _size_4_t) -> None:

@ -194,6 +196,7 @@ class CircularPad3d(_CircularPadNd):
>>> output = m(input)
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int]

def __init__(self, padding: _size_6_t) -> None:

@ -265,6 +268,7 @@ class ConstantPad1d(_ConstantPadNd):
[ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int]

def __init__(self, padding: _size_2_t, value: float) -> None:

@ -316,6 +320,7 @@ class ConstantPad2d(_ConstantPadNd):
"""

__constants__ = ["padding", "value"]
# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int]

def __init__(self, padding: _size_4_t, value: float) -> None:

@ -356,6 +361,7 @@ class ConstantPad3d(_ConstantPadNd):
>>> output = m(input)
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int]

def __init__(self, padding: _size_6_t, value: float) -> None:

@ -409,6 +415,7 @@ class ReflectionPad1d(_ReflectionPadNd):
[7., 6., 5., 4., 5., 6., 7., 6.]]])
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int]

def __init__(self, padding: _size_2_t) -> None:

@ -462,6 +469,7 @@ class ReflectionPad2d(_ReflectionPadNd):
[7., 6., 7., 8., 7.]]]])
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int]

def __init__(self, padding: _size_4_t) -> None:

@ -517,6 +525,7 @@ class ReflectionPad3d(_ReflectionPadNd):
[1., 0., 1., 0.]]]]])
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int]

def __init__(self, padding: _size_6_t) -> None:

@ -570,6 +579,7 @@ class ReplicationPad1d(_ReplicationPadNd):
[4., 4., 4., 4., 5., 6., 7., 7.]]])
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int]

def __init__(self, padding: _size_2_t) -> None:

@ -623,6 +633,7 @@ class ReplicationPad2d(_ReplicationPadNd):
[6., 6., 7., 8., 8.]]]])
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int]

def __init__(self, padding: _size_4_t) -> None:

@ -665,6 +676,7 @@ class ReplicationPad3d(_ReplicationPadNd):
>>> output = m(input)
"""

# pyrefly: ignore # bad-override
padding: tuple[int, int, int, int, int, int]

def __init__(self, padding: _size_6_t) -> None:

@ -111,6 +111,7 @@ class RNNBase(Module):

if (
not isinstance(dropout, numbers.Number)
# pyrefly: ignore # unsupported-operation
or not 0 <= dropout <= 1
or isinstance(dropout, bool)
):
@ -119,6 +120,7 @@ class RNNBase(Module):
"representing the probability of an element being "
"zeroed"
)
# pyrefly: ignore # unsupported-operation
if dropout > 0 and num_layers == 1:
warnings.warn(
"dropout option adds dropout after all but last "
@ -639,15 +641,22 @@ class RNN(RNNBase):

@overload
@torch._jit_internal._overload_method # noqa: F811
# pyrefly: ignore # bad-override
def forward(
self, input: Tensor, hx: Optional[Tensor] = None
self,
input: Tensor,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, Tensor]:
pass

@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: PackedSequence, hx: Optional[Tensor] = None
self,
input: PackedSequence,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[PackedSequence, Tensor]:
pass

@ -772,7 +781,11 @@ class RNN(RNNBase):

if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)

@ -996,6 +1009,7 @@ class LSTM(RNNBase):

# In the future, we should prevent mypy from applying contravariance rules here.
# See torch/nn/modules/module.py::_forward_unimplemented
# pyrefly: ignore # bad-override
def check_forward_args(
self,
input: Tensor,
@ -1029,8 +1043,12 @@ class LSTM(RNNBase):
# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
@overload # type: ignore[override]
@torch._jit_internal._overload_method # noqa: F811
# pyrefly: ignore # bad-override
def forward(
self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
self,
input: Tensor,
hx: Optional[tuple[Tensor, Tensor]] = None,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811
pass

@ -1038,7 +1056,10 @@ class LSTM(RNNBase):
@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None
self,
input: PackedSequence,
hx: Optional[tuple[Tensor, Tensor]] = None,
# pyrefly: ignore # bad-return
) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811
pass

@ -1152,7 +1173,11 @@ class LSTM(RNNBase):
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
@ -1318,15 +1343,22 @@ class GRU(RNNBase):

@overload # type: ignore[override]
@torch._jit_internal._overload_method # noqa: F811
# pyrefly: ignore # bad-override
def forward(
self, input: Tensor, hx: Optional[Tensor] = None
self,
input: Tensor,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[Tensor, Tensor]: # noqa: F811
pass

@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: PackedSequence, hx: Optional[Tensor] = None
self,
input: PackedSequence,
hx: Optional[Tensor] = None,
# pyrefly: ignore # bad-return
) -> tuple[PackedSequence, Tensor]: # noqa: F811
pass

@ -1420,7 +1452,11 @@ class GRU(RNNBase):
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
output,
# pyrefly: ignore # bad-argument-type
batch_sizes,
sorted_indices,
unsorted_indices,
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
@ -135,7 +135,11 @@ class Transformer(Module):
**factory_kwargs,
)
encoder_norm = LayerNorm(
d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
d_model,
eps=layer_norm_eps,
bias=bias,
# pyrefly: ignore # bad-argument-type
**factory_kwargs,
)
self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers, encoder_norm
@ -157,7 +161,11 @@ class Transformer(Module):
**factory_kwargs,
)
decoder_norm = LayerNorm(
d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
d_model,
eps=layer_norm_eps,
bias=bias,
# pyrefly: ignore # bad-argument-type
**factory_kwargs,
)
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, decoder_norm
@ -760,7 +768,9 @@ class TransformerEncoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

self.norm_first = norm_first
# pyrefly: ignore # bad-argument-type
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
@ -1052,8 +1062,11 @@ class TransformerDecoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

self.norm_first = norm_first
# pyrefly: ignore # bad-argument-type
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
@ -36,6 +36,7 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
import torch

if isinstance(out_size, (int, torch.SymInt)):
# pyrefly: ignore # bad-return
return out_size
if len(defaults) <= len(out_size):
raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")
@ -43,6 +43,7 @@ def broadcast(tensor, devices=None, *, out=None):
devices = [_get_device_index(d) for d in devices]
return torch._C._broadcast(tensor, devices)
else:
# pyrefly: ignore # bad-argument-type
return torch._C._broadcast_out(tensor, out)


@ -200,6 +201,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=
"""
tensor = _handle_complex(tensor)
if out is None:
# pyrefly: ignore # not-iterable
devices = [_get_device_index(d) for d in devices]
return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
else:
@ -160,6 +160,7 @@ class DataParallel(Module, Generic[T]):
self.module = module
self.device_ids = [_get_device_index(x, True) for x in device_ids]
self.output_device = _get_device_index(output_device, True)
# pyrefly: ignore # read-only
self.src_device_obj = torch.device(device_type, self.device_ids[0])

if device_type == "cuda":
@ -173,6 +174,7 @@ class DataParallel(Module, Generic[T]):
if not self.device_ids:
return self.module(*inputs, **kwargs)

# pyrefly: ignore # bad-argument-type
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
@ -259,8 +261,10 @@ def data_parallel(

device_ids = [_get_device_index(x, True) for x in device_ids]
output_device = _get_device_index(output_device, True)
# pyrefly: ignore # no-matching-overload
src_device_obj = torch.device(device_type, device_ids[0])

# pyrefly: ignore # bad-argument-type
for t in chain(module.parameters(), module.buffers()):
if t.device != src_device_obj:
raise RuntimeError(
@ -241,6 +241,7 @@ class _BufferCommHook:
# is completed.
class _DDPSink(Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, ddp_weakref, *inputs):
# set_materialize_grads(False) will ensure that None gradients stay as
# None and are not filled with zeros.
@ -691,6 +692,7 @@ class DistributedDataParallel(Module, Joinable):
elif process_group is None and device_mesh is None:
self.process_group = _get_default_group()
elif device_mesh is None:
# pyrefly: ignore # bad-assignment
self.process_group = process_group
else:
if device_mesh.ndim != 1:
@ -779,11 +781,13 @@ class DistributedDataParallel(Module, Joinable):
self.device_ids = None
self.output_device = None
else:
# pyrefly: ignore # bad-assignment
self.device_ids = [_get_device_index(x, True) for x in device_ids]

if output_device is None:
output_device = device_ids[0]

# pyrefly: ignore # bad-assignment
self.output_device = _get_device_index(output_device, True)

self.static_graph = False
@ -933,6 +937,7 @@ class DistributedDataParallel(Module, Joinable):
# enabled.
self._accum_grad_hooks: list[RemovableHandle] = []
if self._use_python_reducer:
# pyrefly: ignore # bad-assignment
torch._inductor.config._fuse_ddp_communication = True
torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
# Directly adding this to the trace rule will disturb the users
@ -56,12 +56,16 @@ def scatter(inputs, target_gpus, dim=0):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if _is_namedtuple(obj):
# pyrefly: ignore # no-matching-overload
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return [list(i) for i in zip(*map(scatter_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
return [obj for _ in target_gpus]

@ -123,9 +127,12 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0)
if isinstance(out, dict):
if not all(len(out) == len(d) for d in outputs):
raise ValueError("All dicts must have the same number of keys")
# pyrefly: ignore # not-callable
return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
if _is_namedtuple(out):
# pyrefly: ignore # no-matching-overload
return type(out)._make(map(gather_map, zip(*outputs)))
# pyrefly: ignore # no-matching-overload
return type(out)(map(gather_map, zip(*outputs)))

# Recursive function calls like this create reference cycles.
@ -81,6 +81,7 @@ class Parameter(torch.Tensor, metaclass=_ParameterMeta):
memo[id(self)] = result
return result

# pyrefly: ignore # bad-override
def __repr__(self):
return "Parameter containing:\n" + super().__repr__()

@ -143,6 +144,7 @@ class UninitializedTensorMixin:
if dtype is None:
dtype = self.data.dtype
self.data = torch.empty(shape, device=device, dtype=dtype)
# pyrefly: ignore # bad-override, missing-attribute
self.__class__ = self.cls_to_become

@property
@ -166,6 +168,7 @@ class UninitializedTensorMixin:

def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
# pyrefly: ignore # missing-attribute
return (self.__class__, (self.requires_grad,))

@classmethod
@ -175,6 +178,7 @@ class UninitializedTensorMixin:
if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
if kwargs is None:
kwargs = {}
# pyrefly: ignore # missing-attribute
return super().__torch_function__(func, types, args, kwargs)
raise ValueError(
f"Attempted to use an uninitialized parameter in {func}. "
@ -216,6 +220,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
def __new__(cls, requires_grad=True, device=None, dtype=None) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
# pyrefly: ignore # bad-return
return torch.Tensor._make_subclass(cls, data, requires_grad)

def __deepcopy__(self, memo):
@ -261,7 +266,9 @@ class Buffer(torch.Tensor, metaclass=_BufferMeta):
data = torch.empty(0)

t = data.detach().requires_grad_(data.requires_grad)
# pyrefly: ignore # missing-attribute
t.persistent = persistent
# pyrefly: ignore # missing-attribute
t._is_buffer = True
return t

@ -292,6 +299,9 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
ret = torch.Tensor._make_subclass(cls, data, requires_grad)
# pyrefly: ignore # missing-attribute
ret.persistent = persistent
# pyrefly: ignore # missing-attribute
ret._is_buffer = True
# pyrefly: ignore # bad-return
return ret
@ -1,5 +1,5 @@
from . import parametrizations, parametrize, rnn, stateless
from .clip_grad import (
from .clip_grad import ( # pyrefly: ignore # deprecated
_clip_grads_with_norm_ as clip_grads_with_norm_,
_get_total_norm as get_total_norm,
clip_grad_norm,
@ -24,6 +24,7 @@ from .expanded_weights_utils import forward_helper
@implements_per_sample_grads(F.conv3d)
class ConvPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx: Any,
kwarg_names: list[str],
@ -56,6 +57,7 @@ class ConvPerSampleGrad(torch.autograd.Function):
f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}"
)

# pyrefly: ignore # invalid-type-var
ctx.conv_fn = conv_fn

ctx.batch_size = orig_input.shape[0]
@ -237,6 +237,7 @@ def conv_unfold_weight_grad_sample(
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
# rearrange the above tensor and extract diagonals.
# pyrefly: ignore # no-matching-overload
weight_grad_sample = weight_grad_sample.view(
n,
groups,
@ -14,6 +14,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.embedding)
class EmbeddingPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(
ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
) -> torch.Tensor:
@ -34,6 +35,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
def backward(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
@ -131,7 +131,9 @@ class ExpandedWeight(torch.Tensor):
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
decomp_opts = expanded_weights_rnn_decomps[func]
use_input_variant = isinstance(
args[2], list
# pyrefly: ignore # index-error
args[2],
list,
) # data variant uses a list here
decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]
@ -8,6 +8,7 @@ from .expanded_weights_impl import ExpandedWeight

def is_batch_first(expanded_args_and_kwargs):
batch_first = None
# pyrefly: ignore # bad-assignment
for arg in expanded_args_and_kwargs:
if not isinstance(arg, ExpandedWeight):
continue
@ -18,6 +18,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.group_norm)
class GroupNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
@ -46,6 +47,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, num_groups = ctx.input, ctx.num_groups
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
@ -94,7 +96,9 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
set_grad_sample_if_exists(
weight,
lambda _: torch.einsum(
"ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output
"ni...->ni",
# pyrefly: ignore # unsupported-operation
F.group_norm(input, num_groups, eps=eps) * grad_output,
),
)
if hasattr(ctx, "bias"):
@ -17,6 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.instance_norm)
class InstanceNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
instance_norm = partial(torch.instance_norm, cudnn_enabled=True)
expanded_args, expanded_kwargs = standard_kwargs(
@ -36,6 +37,7 @@ class InstanceNormPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
@ -17,6 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.layer_norm)
class LayerNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
@ -42,6 +43,7 @@ class LayerNormPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
def weight_per_sample_grad(weight):
return sum_over_all_but_batch_and_last_n(
@ -16,6 +16,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.linear)
class LinearPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, _, __, *expanded_args_and_kwargs):
if len(expanded_args_and_kwargs[0].shape) <= 1:
raise RuntimeError(
@ -35,6 +36,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
def backward(ctx, grad_output):
input, weight = ctx.args
bias = ctx.kwargs["bias"]
@ -77,6 +77,7 @@ def swap_tensor(
setattr(module, name, tensor)
elif hasattr(module, name):
delattr(module, name)
# pyrefly: ignore # bad-return
return orig_tensor
@ -41,9 +41,11 @@ def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:

def _no_grad_wrapper(*args, **kwargs):
with torch.no_grad():
# pyrefly: ignore # invalid-param-spec
return func(*args, **kwargs)

functools.update_wrapper(_no_grad_wrapper, func)
# pyrefly: ignore # bad-return
return _no_grad_wrapper


@ -84,6 +84,7 @@ def convert_conv2d_weight_memory_format(
)
for child in module.children():
convert_conv2d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
return module


@ -163,6 +164,7 @@ def convert_conv3d_weight_memory_format(
)
for child in module.children():
convert_conv3d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
return module
@ -98,6 +98,7 @@ class _Orthogonal(Module):
)
# Q is now orthogonal (or unitary) of size (..., n, n)
if n != k:
# pyrefly: ignore # unbound-name
Q = Q[..., :k]
# Q is now the size of the X (albeit perhaps transposed)
else:
@ -179,23 +179,28 @@ class ParametrizationList(ModuleList):

# Register the tensor(s)
if self.is_tensor:
# pyrefly: ignore # missing-attribute
if original.dtype != new.dtype:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the dtype.\n"
f"original.dtype: {original.dtype}\n"
# pyrefly: ignore # missing-attribute
f"right_inverse(original).dtype: {new.dtype}"
)

# pyrefly: ignore # missing-attribute
if original.device != new.device:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the device.\n"
f"original.device: {original.device}\n"
# pyrefly: ignore # missing-attribute
f"right_inverse(original).device: {new.device}"
)

# Set the original to original so that the user does not need to re-register the parameter
# manually in the optimiser
with torch.no_grad():
# pyrefly: ignore # bad-argument-type
_maybe_set(original, new)
_register_parameter_or_buffer(self, "original", original)
else:
@ -396,6 +401,7 @@ def _inject_property(module: Module, tensor_name: str) -> None:
if torch.jit.is_scripting():
raise RuntimeError("Parametrization is not working with scripting.")
parametrization = self.parametrizations[tensor_name]
# pyrefly: ignore # redundant-condition
if _cache_enabled:
if torch.jit.is_scripting():
# Scripting
@ -695,6 +701,7 @@ def remove_parametrizations(
# Fetch the original tensor
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
parametrizations = module.parametrizations[tensor_name]
# pyrefly: ignore # invalid-argument
if parametrizations.is_tensor:
original = parametrizations.original
assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"
@ -274,8 +274,11 @@ class PruningContainer(BasePruningMethod):
if not isinstance(args, Iterable): # only 1 item
self._tensor_name = args._tensor_name
self.add_pruning_method(args)
# pyrefly: ignore # bad-argument-type
elif len(args) == 1: # only 1 item in a tuple
# pyrefly: ignore # index-error
self._tensor_name = args[0]._tensor_name
# pyrefly: ignore # index-error
self.add_pruning_method(args[0])
else: # manual construction from list or other iterable (or no args)
for method in args:
@ -1097,6 +1100,7 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw

# flatten importance scores to consider them all at once in global pruning
relevant_importance_scores = torch.nn.utils.parameters_to_vector(
# pyrefly: ignore # bad-argument-type
[
importance_scores.get((module, name), getattr(module, name))
for (module, name) in parameters
@ -332,6 +332,7 @@ def spectral_norm(
else:
dim = 0
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
# pyrefly: ignore # bad-return
return module
@ -41,10 +41,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
)


# pyrefly: ignore # import-error
import optree

# pyrefly: ignore # import-error
from optree import PyTreeSpec as TreeSpec # direct import for type annotations


@ -679,7 +679,7 @@ class FlopCounterMode:


import tabulate
# pyrefly: ignore # bad-assignment

tabulate.PRESERVE_WHITESPACE = True
header = ["Module", "FLOP", "% Total"]
values = []
@ -9,7 +9,7 @@ from typing import Any, Optional
import torch
import numpy as np

# pyrefly: ignore # import-error

from google.protobuf import struct_pb2

from tensorboard.compat.proto.summary_pb2 import (
@ -956,7 +956,7 @@ class SummaryWriter:
)
self._projector_config.embeddings.extend([embedding_info])

# pyrefly: ignore # import-error

from google.protobuf import text_format

config_pbtxt = text_format.MessageToString(self._projector_config)