Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Pyrefly suppressions 7/n (#164913)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283. Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

Step 1: delete lines in the pyrefly.toml file from the project-excludes field
Step 2: run pyrefly check
Step 3: add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199
After: INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913
Approved by: https://github.com/oulgen
Committed by: PyTorch MergeBot
Parent: 12d2ef557f
Commit: c855f8632e
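For reference, the suppressions this PR adds are inline comments of the form "# pyrefly: ignore # <error-code>", placed on their own line directly above the statement that pyrefly flags (unlike the trailing "# type: ignore[...]" mypy comments visible in the context lines below). A minimal, illustrative sketch of the pattern, not taken verbatim from this diff:

    import torch

    # Complex tensor similar to the one used in the test_complex hunk below.
    x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6])
    # pyrefly: ignore # missing-attribute
    print(torch.all(x))

Removing a directory glob such as "torch/_dynamo/**" from project-excludes in pyrefly.toml brings those files into the check, which is presumably why most of the hunks below add suppression comments to torch/_dynamo sources.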
@ -22,14 +22,16 @@ project-excludes = [
|
|||||||
# ==== to test Pyrefly on a specific directory, simply comment it out ====
|
# ==== to test Pyrefly on a specific directory, simply comment it out ====
|
||||||
"torch/_inductor/**",
|
"torch/_inductor/**",
|
||||||
"torch/distributed/**",
|
"torch/distributed/**",
|
||||||
"torch/nn/**",
|
|
||||||
"torch/_dynamo/**",
|
|
||||||
# formatting issues
|
# formatting issues
|
||||||
"torch/linalg/__init__.py",
|
"torch/linalg/__init__.py",
|
||||||
"torch/package/importer.py",
|
"torch/package/importer.py",
|
||||||
"torch/package/_package_pickler.py",
|
"torch/package/_package_pickler.py",
|
||||||
"torch/jit/annotations.py",
|
"torch/jit/annotations.py",
|
||||||
"torch/utils/data/datapipes/_typing.py",
|
"torch/utils/data/datapipes/_typing.py",
|
||||||
|
"torch/nn/functional.py",
|
||||||
|
"torch/_export/utils.py",
|
||||||
|
"torch/fx/experimental/unification/multipledispatch/__init__.py",
|
||||||
|
"torch/nn/modules/__init__.py",
|
||||||
# ====
|
# ====
|
||||||
"benchmarks/instruction_counts/main.py",
|
"benchmarks/instruction_counts/main.py",
|
||||||
"benchmarks/instruction_counts/definitions/setup.py",
|
"benchmarks/instruction_counts/definitions/setup.py",
|
||||||
|
@ -59,7 +59,6 @@ class TestBundledInputs(TestCase):
|
|||||||
# despite having nominally large bundled inputs.
|
# despite having nominally large bundled inputs.
|
||||||
augmented_size = model_size(sm)
|
augmented_size = model_size(sm)
|
||||||
|
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertLess(augmented_size, original_size + (1 << 12))
|
self.assertLess(augmented_size, original_size + (1 << 12))
|
||||||
|
|
||||||
loaded = save_and_load(sm)
|
loaded = save_and_load(sm)
|
||||||
@ -67,15 +66,12 @@ class TestBundledInputs(TestCase):
|
|||||||
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
|
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
|
||||||
self.assertEqual(len(inflated), len(samples))
|
self.assertEqual(len(inflated), len(samples))
|
||||||
|
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
|
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
|
||||||
|
|
||||||
for idx, inp in enumerate(inflated):
|
for idx, inp in enumerate(inflated):
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertIsInstance(inp, tuple)
|
self.assertIsInstance(inp, tuple)
|
||||||
self.assertEqual(len(inp), 1)
|
self.assertEqual(len(inp), 1)
|
||||||
|
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertIsInstance(inp[0], torch.Tensor)
|
self.assertIsInstance(inp[0], torch.Tensor)
|
||||||
if idx != 5:
|
if idx != 5:
|
||||||
# Strides might be important for benchmarking.
|
# Strides might be important for benchmarking.
|
||||||
@ -144,7 +140,6 @@ class TestBundledInputs(TestCase):
|
|||||||
inflated = loaded.get_all_bundled_inputs()
|
inflated = loaded.get_all_bundled_inputs()
|
||||||
self.assertEqual(inflated, samples)
|
self.assertEqual(inflated, samples)
|
||||||
|
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertTrue(loaded(*inflated[0]) == "first 1")
|
self.assertTrue(loaded(*inflated[0]) == "first 1")
|
||||||
|
|
||||||
def test_multiple_methods_with_inputs(self):
|
def test_multiple_methods_with_inputs(self):
|
||||||
@ -192,7 +187,6 @@ class TestBundledInputs(TestCase):
|
|||||||
|
|
||||||
# Check running and size helpers
|
# Check running and size helpers
|
||||||
|
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
|
self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
|
||||||
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
|
self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
|
||||||
|
|
||||||
@ -426,7 +420,6 @@ class TestBundledInputs(TestCase):
|
|||||||
augmented_size = model_size(sm)
|
augmented_size = model_size(sm)
|
||||||
# assert the size has not increased more than 8KB
|
# assert the size has not increased more than 8KB
|
||||||
|
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertLess(augmented_size, original_size + (1 << 13))
|
self.assertLess(augmented_size, original_size + (1 << 13))
|
||||||
|
|
||||||
loaded = save_and_load(sm)
|
loaded = save_and_load(sm)
|
||||||
|
@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
|
|||||||
def test_all(self, device, dtype):
|
def test_all(self, device, dtype):
|
||||||
# issue: https://github.com/pytorch/pytorch/issues/120875
|
# issue: https://github.com/pytorch/pytorch/issues/120875
|
||||||
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
|
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertTrue(torch.all(x))
|
self.assertTrue(torch.all(x))
|
||||||
|
|
||||||
@dtypes(*complex_types())
|
@dtypes(*complex_types())
|
||||||
@ -57,7 +57,7 @@ class TestComplexTensor(TestCase):
|
|||||||
x = torch.tensor(
|
x = torch.tensor(
|
||||||
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
|
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertFalse(torch.any(x))
|
self.assertFalse(torch.any(x))
|
||||||
|
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
|
@ -142,7 +142,6 @@ class TestTypeHints(TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
if result != 0:
|
if result != 0:
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.fail(f"mypy failed:\n{stderr}\n{stdout}")
|
self.fail(f"mypy failed:\n{stderr}\n{stdout}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
|
|||||||
# Regression test for https://github.com/pytorch/pytorch/issues/124868
|
# Regression test for https://github.com/pytorch/pytorch/issues/124868
|
||||||
# If reference count is leaked this would be a set of 10 elements
|
# If reference count is leaked this would be a set of 10 elements
|
||||||
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
|
ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertLess(len(ref_cnt), 3)
|
self.assertLess(len(ref_cnt), 3)
|
||||||
|
|
||||||
self.assertEqual(torch.float64.to_complex(), torch.complex128)
|
self.assertEqual(torch.float64.to_complex(), torch.complex128)
|
||||||
@ -136,7 +136,7 @@ class TestDTypeInfo(TestCase):
|
|||||||
# Regression test for https://github.com/pytorch/pytorch/issues/124868
|
# Regression test for https://github.com/pytorch/pytorch/issues/124868
|
||||||
# If reference count is leaked this would be a set of 10 elements
|
# If reference count is leaked this would be a set of 10 elements
|
||||||
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
|
ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
self.assertLess(len(ref_cnt), 3)
|
self.assertLess(len(ref_cnt), 3)
|
||||||
|
|
||||||
self.assertEqual(torch.complex128.to_real(), torch.double)
|
self.assertEqual(torch.complex128.to_real(), torch.double)
|
||||||
|
@ -53,6 +53,8 @@ from .eval_frame import (
|
|||||||
OptimizedModule,
|
OptimizedModule,
|
||||||
reset_code,
|
reset_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore # deprecated
|
||||||
from .external_utils import is_compiling
|
from .external_utils import is_compiling
|
||||||
from .mutation_guard import GenerationTracker
|
from .mutation_guard import GenerationTracker
|
||||||
from .pgo import reset_code_state
|
from .pgo import reset_code_state
|
||||||
|
@ -95,6 +95,7 @@ class ModIndex(torch.autograd.Function):
|
|||||||
generate_vmap_rule = True
|
generate_vmap_rule = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(x: Tensor, indices: list[Tensor]) -> Tensor:
|
def forward(x: Tensor, indices: list[Tensor]) -> Tensor:
|
||||||
return torch.ops.aten.index(x, indices)
|
return torch.ops.aten.index(x, indices)
|
||||||
|
|
||||||
@ -242,6 +243,7 @@ def _trace_wrapped_functionalized(ctx: Any, *args: Any, **kwargs: Any) -> Any:
|
|||||||
|
|
||||||
def autograd_function_backward_rewritten(original_backward: Any) -> Any:
|
def autograd_function_backward_rewritten(original_backward: Any) -> Any:
|
||||||
def new_backward(ctx: Any, *grads: Any) -> Any:
|
def new_backward(ctx: Any, *grads: Any) -> Any:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
grads = [g.contiguous() for g in grads]
|
grads = [g.contiguous() for g in grads]
|
||||||
return original_backward(ctx, *grads)
|
return original_backward(ctx, *grads)
|
||||||
|
|
||||||
|
@ -89,6 +89,7 @@ class AOTCompiledFunction:
|
|||||||
**import_sources,
|
**import_sources,
|
||||||
self._artifacts.backend_id: self._artifacts.compiled_fn,
|
self._artifacts.backend_id: self._artifacts.compiled_fn,
|
||||||
}
|
}
|
||||||
|
# pyrefly: ignore # read-only
|
||||||
self.fn = types.FunctionType(
|
self.fn = types.FunctionType(
|
||||||
self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
|
self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
|
||||||
)
|
)
|
||||||
|
@ -206,6 +206,7 @@ def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any])
|
|||||||
assert manager is not None
|
assert manager is not None
|
||||||
|
|
||||||
def fn(inputs: list[Any]) -> Any:
|
def fn(inputs: list[Any]) -> Any:
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
manager.set_to_running_backward()
|
manager.set_to_running_backward()
|
||||||
return aot_model(inputs)
|
return aot_model(inputs)
|
||||||
|
|
||||||
|
@ -77,16 +77,19 @@ def tvm(
|
|||||||
opt_level = options.get("opt_level", 3)
|
opt_level = options.get("opt_level", 3)
|
||||||
|
|
||||||
if scheduler == "auto_scheduler":
|
if scheduler == "auto_scheduler":
|
||||||
|
# pyrefly: ignore # import-error
|
||||||
from tvm import auto_scheduler
|
from tvm import auto_scheduler
|
||||||
|
|
||||||
log_file = tempfile.NamedTemporaryFile()
|
log_file = tempfile.NamedTemporaryFile()
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if not os.path.exists(log_file):
|
if not os.path.exists(log_file):
|
||||||
tasks, task_weights = auto_scheduler.extract_tasks(
|
tasks, task_weights = auto_scheduler.extract_tasks(
|
||||||
mod["main"], params, target
|
mod["main"], params, target
|
||||||
)
|
)
|
||||||
if len(tasks) != 0:
|
if len(tasks) != 0:
|
||||||
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
|
tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if not os.path.exists(log_file):
|
if not os.path.exists(log_file):
|
||||||
assert trials > 0
|
assert trials > 0
|
||||||
tune_option = auto_scheduler.TuningOptions(
|
tune_option = auto_scheduler.TuningOptions(
|
||||||
@ -97,7 +100,9 @@ def tvm(
|
|||||||
try:
|
try:
|
||||||
tuner.tune(tune_option)
|
tuner.tune(tune_option)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if os.path.exists(log_file):
|
if os.path.exists(log_file):
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
os.unlink(log_file)
|
os.unlink(log_file)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -107,6 +112,7 @@ def tvm(
|
|||||||
):
|
):
|
||||||
lib = relay.build(mod, target=target, params=params)
|
lib = relay.build(mod, target=target, params=params)
|
||||||
elif scheduler == "meta_schedule":
|
elif scheduler == "meta_schedule":
|
||||||
|
# pyrefly: ignore # import-error
|
||||||
from tvm import meta_schedule as ms
|
from tvm import meta_schedule as ms
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as work_dir:
|
with tempfile.TemporaryDirectory() as work_dir:
|
||||||
|
@ -37,6 +37,7 @@ if sys.version_info >= (3, 11):
|
|||||||
TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
|
TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
|
||||||
else:
|
else:
|
||||||
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
|
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
if (3, 12) <= sys.version_info < (3, 14):
|
if (3, 12) <= sys.version_info < (3, 14):
|
||||||
TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
|
TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
|
||||||
if sys.version_info >= (3, 13):
|
if sys.version_info >= (3, 13):
|
||||||
|
@ -903,6 +903,7 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None:
|
|||||||
inst.arg = abs(
|
inst.arg = abs(
|
||||||
int(target.offset - inst.offset - instruction_size(inst))
|
int(target.offset - inst.offset - instruction_size(inst))
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
inst.arg //= 2
|
inst.arg //= 2
|
||||||
inst.argval = target.offset
|
inst.argval = target.offset
|
||||||
inst.argrepr = f"to {target.offset}"
|
inst.argrepr = f"to {target.offset}"
|
||||||
@ -1354,6 +1355,7 @@ def update_offsets(instructions: Sequence[Instruction]) -> None:
|
|||||||
offset = 0
|
offset = 0
|
||||||
for inst in instructions:
|
for inst in instructions:
|
||||||
inst.offset = offset
|
inst.offset = offset
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
offset += instruction_size(inst)
|
offset += instruction_size(inst)
|
||||||
|
|
||||||
|
|
||||||
|
@ -464,6 +464,7 @@ def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
|
|||||||
try:
|
try:
|
||||||
prof.enable()
|
prof.enable()
|
||||||
start_ts = time.time()
|
start_ts = time.time()
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
retval = prof.runcall(func, *args, **kwargs)
|
retval = prof.runcall(func, *args, **kwargs)
|
||||||
profile_latency = time.time() - start_ts
|
profile_latency = time.time() - start_ts
|
||||||
prof.disable()
|
prof.disable()
|
||||||
@ -957,6 +958,7 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
|
|||||||
if isinstance(mod, torch.nn.Module):
|
if isinstance(mod, torch.nn.Module):
|
||||||
mod = mod.forward
|
mod = mod.forward
|
||||||
if hasattr(mod, "__self__"):
|
if hasattr(mod, "__self__"):
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return mod.__func__, mod.__self__
|
return mod.__func__, mod.__self__
|
||||||
elif inspect.isfunction(mod):
|
elif inspect.isfunction(mod):
|
||||||
return mod, None
|
return mod, None
|
||||||
@ -1096,6 +1098,7 @@ def _fullgraph_capture_frame(
|
|||||||
while cur_exn.__cause__ is not None:
|
while cur_exn.__cause__ is not None:
|
||||||
cur_exn.__cause__.with_traceback(None)
|
cur_exn.__cause__.with_traceback(None)
|
||||||
cur_exn = cur_exn.__cause__
|
cur_exn = cur_exn.__cause__
|
||||||
|
# pyrefly: ignore # invalid-inheritance
|
||||||
raise e.with_traceback(None) from e.__cause__ # User compiler error
|
raise e.with_traceback(None) from e.__cause__ # User compiler error
|
||||||
|
|
||||||
return CaptureOutput(
|
return CaptureOutput(
|
||||||
@ -1119,6 +1122,7 @@ def compile_frame( # type: ignore[return]
|
|||||||
frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
|
frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
|
||||||
distributed_state: Optional[DistributedState] = None,
|
distributed_state: Optional[DistributedState] = None,
|
||||||
package: Optional[CompilePackage] = None,
|
package: Optional[CompilePackage] = None,
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
) -> DynamoOutput:
|
) -> DynamoOutput:
|
||||||
"""
|
"""
|
||||||
A helper function taking a frame and backend, then return the generated bytecode
|
A helper function taking a frame and backend, then return the generated bytecode
|
||||||
|
@ -20,6 +20,7 @@ allowed to compute gradients on).
|
|||||||
|
|
||||||
class TracableCreateParameter(torch.autograd.Function):
|
class TracableCreateParameter(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter:
|
def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter:
|
||||||
assert not tensor.requires_grad
|
assert not tensor.requires_grad
|
||||||
return placeholder.set_(tensor)
|
return placeholder.set_(tensor)
|
||||||
|
@ -879,6 +879,7 @@ def aot_graph_input_parser(
|
|||||||
data_type, shape_str = match.groups()
|
data_type, shape_str = match.groups()
|
||||||
shape = tuple(shape_str.split(","))
|
shape = tuple(shape_str.split(","))
|
||||||
dtype = dtype_map[data_type]
|
dtype = dtype_map[data_type]
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
kwargs[param] = gen_tensor(shape, dtype)
|
kwargs[param] = gen_tensor(shape, dtype)
|
||||||
|
|
||||||
match = re.search(sym_shape_regex, annotation)
|
match = re.search(sym_shape_regex, annotation)
|
||||||
@ -892,6 +893,7 @@ def aot_graph_input_parser(
|
|||||||
attr_name, data_type, shape_str, _ = match.groups()
|
attr_name, data_type, shape_str, _ = match.groups()
|
||||||
shape = tuple(shape_str.split(","))
|
shape = tuple(shape_str.split(","))
|
||||||
dtype = dtype_map[data_type]
|
dtype = dtype_map[data_type]
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
setattr(container, attr_name, gen_tensor(shape, dtype))
|
setattr(container, attr_name, gen_tensor(shape, dtype))
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
@ -95,6 +95,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ig
|
|||||||
nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined]
|
nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined]
|
||||||
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined]
|
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined]
|
||||||
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
|
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return nonrecursive_disable_wrapper
|
return nonrecursive_disable_wrapper
|
||||||
|
|
||||||
if fn is None:
|
if fn is None:
|
||||||
@ -306,6 +307,7 @@ def forbid_in_graph(fn: Any) -> Any:
|
|||||||
if isinstance(fn, (list, tuple)):
|
if isinstance(fn, (list, tuple)):
|
||||||
return [forbid_in_graph(x) for x in fn]
|
return [forbid_in_graph(x) for x in fn]
|
||||||
assert callable(fn), "forbid_in_graph applies only to callables"
|
assert callable(fn), "forbid_in_graph applies only to callables"
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
fn._dynamo_forbidden = True
|
fn._dynamo_forbidden = True
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
@ -653,21 +655,28 @@ def mark_dynamic(
|
|||||||
|
|
||||||
if isinstance(index, int):
|
if isinstance(index, int):
|
||||||
if not hasattr(t, "_dynamo_dynamic_indices"):
|
if not hasattr(t, "_dynamo_dynamic_indices"):
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._dynamo_dynamic_indices = set()
|
t._dynamo_dynamic_indices = set()
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._dynamo_dynamic_range = set()
|
t._dynamo_dynamic_range = set()
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._dynamo_hint_overrides = {}
|
t._dynamo_hint_overrides = {}
|
||||||
|
|
||||||
if not hasattr(t, "_specialize_on"):
|
if not hasattr(t, "_specialize_on"):
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._specialize_on = {}
|
t._specialize_on = {}
|
||||||
|
|
||||||
if hint_override:
|
if hint_override:
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._dynamo_hint_overrides[index] = hint_override
|
t._dynamo_hint_overrides[index] = hint_override
|
||||||
# TODO(voz): Should we bounds check?
|
# TODO(voz): Should we bounds check?
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._dynamo_dynamic_indices.add(index)
|
t._dynamo_dynamic_indices.add(index)
|
||||||
t._dynamo_dynamic_range.add(_DimRange(index, min, max)) # type: ignore[arg-type]
|
t._dynamo_dynamic_range.add(_DimRange(index, min, max)) # type: ignore[arg-type]
|
||||||
|
|
||||||
# FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
|
# FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
|
||||||
# TypeError: 'Attribute' object does not support item assignment
|
# TypeError: 'Attribute' object does not support item assignment
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if isinstance(t._specialize_on, dict):
|
if isinstance(t._specialize_on, dict):
|
||||||
t._specialize_on[index] = specialize_on if specialize_on is not None else []
|
t._specialize_on[index] = specialize_on if specialize_on is not None else []
|
||||||
|
|
||||||
@ -692,8 +701,10 @@ def maybe_mark_dynamic(t: Any, index: Union[int, list[Any], tuple[Any]]) -> None
|
|||||||
|
|
||||||
if isinstance(index, int):
|
if isinstance(index, int):
|
||||||
if not hasattr(t, "_dynamo_weak_dynamic_indices"):
|
if not hasattr(t, "_dynamo_weak_dynamic_indices"):
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._dynamo_weak_dynamic_indices = set()
|
t._dynamo_weak_dynamic_indices = set()
|
||||||
# TODO(voz): Should we bounds check?
|
# TODO(voz): Should we bounds check?
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._dynamo_weak_dynamic_indices.add(index)
|
t._dynamo_weak_dynamic_indices.add(index)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -745,8 +756,11 @@ def mark_static(
|
|||||||
# TODO: Make this configurable via a supported public API
|
# TODO: Make this configurable via a supported public API
|
||||||
_apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)
|
_apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module):
|
if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module):
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._dynamo_marked_static = True
|
t._dynamo_marked_static = True
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return t
|
return t
|
||||||
|
|
||||||
if not isinstance(t, torch.Tensor):
|
if not isinstance(t, torch.Tensor):
|
||||||
|
@ -205,6 +205,7 @@ class CudaInterface(DeviceInterface):
|
|||||||
Event = torch.cuda.Event # type: ignore[assignment]
|
Event = torch.cuda.Event # type: ignore[assignment]
|
||||||
Stream = torch.cuda.Stream # type: ignore[assignment]
|
Stream = torch.cuda.Stream # type: ignore[assignment]
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
class Worker:
|
class Worker:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_device(device: int) -> None:
|
def set_device(device: int) -> None:
|
||||||
@ -240,6 +241,7 @@ class CudaInterface(DeviceInterface):
|
|||||||
set_device = staticmethod(torch.cuda.set_device)
|
set_device = staticmethod(torch.cuda.set_device)
|
||||||
device_count = staticmethod(torch.cuda.device_count)
|
device_count = staticmethod(torch.cuda.device_count)
|
||||||
stream = staticmethod(torch.cuda.stream) # type: ignore[assignment]
|
stream = staticmethod(torch.cuda.stream) # type: ignore[assignment]
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
current_stream = staticmethod(torch.cuda.current_stream)
|
current_stream = staticmethod(torch.cuda.current_stream)
|
||||||
set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment]
|
set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment]
|
||||||
_set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment]
|
_set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment]
|
||||||
@ -300,6 +302,7 @@ class MtiaInterface(DeviceInterface):
|
|||||||
Event = torch.mtia.Event # type: ignore[assignment]
|
Event = torch.mtia.Event # type: ignore[assignment]
|
||||||
Stream = torch.mtia.Stream # type: ignore[assignment]
|
Stream = torch.mtia.Stream # type: ignore[assignment]
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
class Worker:
|
class Worker:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_device(device: int) -> None:
|
def set_device(device: int) -> None:
|
||||||
@ -335,6 +338,7 @@ class MtiaInterface(DeviceInterface):
|
|||||||
set_device = staticmethod(torch.mtia.set_device) # type: ignore[assignment]
|
set_device = staticmethod(torch.mtia.set_device) # type: ignore[assignment]
|
||||||
device_count = staticmethod(torch.mtia.device_count)
|
device_count = staticmethod(torch.mtia.device_count)
|
||||||
stream = staticmethod(torch.mtia.stream) # type: ignore[assignment]
|
stream = staticmethod(torch.mtia.stream) # type: ignore[assignment]
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
current_stream = staticmethod(torch.mtia.current_stream)
|
current_stream = staticmethod(torch.mtia.current_stream)
|
||||||
set_stream = staticmethod(torch.mtia.set_stream) # type: ignore[assignment]
|
set_stream = staticmethod(torch.mtia.set_stream) # type: ignore[assignment]
|
||||||
_set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id) # type: ignore[assignment]
|
_set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id) # type: ignore[assignment]
|
||||||
@ -381,6 +385,7 @@ class XpuInterface(DeviceInterface):
|
|||||||
Event = torch.xpu.Event # type: ignore[assignment]
|
Event = torch.xpu.Event # type: ignore[assignment]
|
||||||
Stream = torch.xpu.Stream # type: ignore[assignment]
|
Stream = torch.xpu.Stream # type: ignore[assignment]
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
class Worker:
|
class Worker:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_device(device: int) -> None:
|
def set_device(device: int) -> None:
|
||||||
@ -416,6 +421,7 @@ class XpuInterface(DeviceInterface):
|
|||||||
set_device = staticmethod(torch.xpu.set_device)
|
set_device = staticmethod(torch.xpu.set_device)
|
||||||
device_count = staticmethod(torch.xpu.device_count) # type: ignore[has-type]
|
device_count = staticmethod(torch.xpu.device_count) # type: ignore[has-type]
|
||||||
stream = staticmethod(torch.xpu.stream) # type: ignore[assignment]
|
stream = staticmethod(torch.xpu.stream) # type: ignore[assignment]
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
current_stream = staticmethod(torch.xpu.current_stream)
|
current_stream = staticmethod(torch.xpu.current_stream)
|
||||||
set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment]
|
set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment]
|
||||||
_set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id) # type: ignore[assignment]
|
_set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id) # type: ignore[assignment]
|
||||||
@ -458,6 +464,7 @@ class CpuDeviceProperties:
|
|||||||
|
|
||||||
|
|
||||||
class CpuInterface(DeviceInterface):
|
class CpuInterface(DeviceInterface):
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
class Event(torch.Event):
|
class Event(torch.Event):
|
||||||
def __init__(self, enable_timing: bool = True) -> None:
|
def __init__(self, enable_timing: bool = True) -> None:
|
||||||
self.time = 0.0
|
self.time = 0.0
|
||||||
@ -468,6 +475,7 @@ class CpuInterface(DeviceInterface):
|
|||||||
def record(self, stream: Any = None) -> None:
|
def record(self, stream: Any = None) -> None:
|
||||||
self.time = time.perf_counter()
|
self.time = time.perf_counter()
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
class Worker:
|
class Worker:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_device_properties(
|
def get_device_properties(
|
||||||
@ -543,6 +551,7 @@ class MpsInterface(DeviceInterface):
|
|||||||
def synchronize(device: torch.types.Device = None) -> None:
|
def synchronize(device: torch.types.Device = None) -> None:
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
class Worker:
|
class Worker:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_device_properties(device: torch.types.Device = None) -> Any:
|
def get_device_properties(device: torch.types.Device = None) -> Any:
|
||||||
|
@ -484,6 +484,7 @@ class OptimizedModule(torch.nn.Module):
|
|||||||
self._initialize()
|
self._initialize()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def training(self) -> bool:
|
def training(self) -> bool:
|
||||||
return self._orig_mod.training
|
return self._orig_mod.training
|
||||||
|
|
||||||
@ -892,6 +893,7 @@ class _TorchDynamoContext:
|
|||||||
while cur_exn.__cause__ is not None:
|
while cur_exn.__cause__ is not None:
|
||||||
cur_exn.__cause__.with_traceback(None)
|
cur_exn.__cause__.with_traceback(None)
|
||||||
cur_exn = cur_exn.__cause__
|
cur_exn = cur_exn.__cause__
|
||||||
|
# pyrefly: ignore # invalid-inheritance
|
||||||
raise e.with_traceback(None) from e.__cause__ # User compiler error
|
raise e.with_traceback(None) from e.__cause__ # User compiler error
|
||||||
except ShortenTraceback as e:
|
except ShortenTraceback as e:
|
||||||
# Failures in the backend likely don't have useful
|
# Failures in the backend likely don't have useful
|
||||||
@ -1020,7 +1022,10 @@ class OptimizeContext(_TorchDynamoContext):
|
|||||||
assert rebuild_ctx is not None
|
assert rebuild_ctx is not None
|
||||||
compiler_fn = rebuild_ctx()
|
compiler_fn = rebuild_ctx()
|
||||||
ctx = torch._dynamo.compiled_autograd._enable(
|
ctx = torch._dynamo.compiled_autograd._enable(
|
||||||
compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False
|
compiler_fn,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
dynamic=_dynamic,
|
||||||
|
ignore_active_disable_ctx=False,
|
||||||
)
|
)
|
||||||
ctx.__enter__()
|
ctx.__enter__()
|
||||||
return functools.partial(ctx.__exit__, None, None, None)
|
return functools.partial(ctx.__exit__, None, None, None)
|
||||||
@ -1083,6 +1088,7 @@ class DisableContext(_TorchDynamoContext):
|
|||||||
cls_obj.__call__ = self(cls_obj.__call__)
|
cls_obj.__call__ = self(cls_obj.__call__)
|
||||||
if issubclass(cls_obj, torch.nn.Module):
|
if issubclass(cls_obj, torch.nn.Module):
|
||||||
# NN module variable tracker directly inlines the _call_impl. Disable it.
|
# NN module variable tracker directly inlines the _call_impl. Disable it.
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
cls_obj._call_impl = self(cls_obj._call_impl)
|
cls_obj._call_impl = self(cls_obj._call_impl)
|
||||||
return cls_obj
|
return cls_obj
|
||||||
|
|
||||||
@ -1988,6 +1994,7 @@ def export(
|
|||||||
path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any]
|
path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if isinstance(t, torch.Tensor):
|
if isinstance(t, torch.Tensor):
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return ambient_fake_mode.from_tensor(t, static_shapes=True)
|
return ambient_fake_mode.from_tensor(t, static_shapes=True)
|
||||||
elif isinstance(t, _IntWrapper):
|
elif isinstance(t, _IntWrapper):
|
||||||
if (
|
if (
|
||||||
@ -2068,8 +2075,11 @@ def export(
|
|||||||
)
|
)
|
||||||
and not trace_rules.check(call_to_inspect)
|
and not trace_rules.check(call_to_inspect)
|
||||||
):
|
):
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
dim_constraints.solve()
|
dim_constraints.solve()
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
forced_specializations = dim_constraints.forced_specializations()
|
forced_specializations = dim_constraints.forced_specializations()
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
msg = dim_constraints.prettify_results(
|
msg = dim_constraints.prettify_results(
|
||||||
original_signature,
|
original_signature,
|
||||||
dynamic_shapes,
|
dynamic_shapes,
|
||||||
@ -2090,9 +2100,11 @@ def export(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Error if we have any constraints on static values
|
# Error if we have any constraints on static values
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
for k in shape_env.var_to_range.keys():
|
for k in shape_env.var_to_range.keys():
|
||||||
if isinstance(k, sympy.Integer):
|
if isinstance(k, sympy.Integer):
|
||||||
constraint_violation_error = ConstraintViolationError(
|
constraint_violation_error = ConstraintViolationError(
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
||||||
"It appears that you're trying to set a constraint on a "
|
"It appears that you're trying to set a constraint on a "
|
||||||
f"value which we evaluated to have a static value of {k}. "
|
f"value which we evaluated to have a static value of {k}. "
|
||||||
|
@ -369,6 +369,7 @@ def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedExc
|
|||||||
observed_exception_map[exc_type] = type( # type: ignore[assignment]
|
observed_exception_map[exc_type] = type( # type: ignore[assignment]
|
||||||
f"Observed{name}Error", (ObservedException,), {}
|
f"Observed{name}Error", (ObservedException,), {}
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # index-error
|
||||||
return observed_exception_map[exc_type]
|
return observed_exception_map[exc_type]
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,7 +96,9 @@ def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
|
|||||||
args, kwargs = pytree.tree_map_only(
|
args, kwargs = pytree.tree_map_only(
|
||||||
torch.Tensor, lambda x: x.numpy(), (args, kwargs)
|
torch.Tensor, lambda x: x.numpy(), (args, kwargs)
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # invalid-param-spec
|
||||||
out = f(*args, **kwargs)
|
out = f(*args, **kwargs)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)
|
return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)
|
||||||
|
|
||||||
return wrap
|
return wrap
|
||||||
|
@ -250,6 +250,7 @@ class DynamoGraphTransformer(torch.fx.Transformer):
|
|||||||
else:
|
else:
|
||||||
placeholder.node.meta["val"] = self.flat_inputs[i]
|
placeholder.node.meta["val"] = self.flat_inputs[i]
|
||||||
|
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
self.new_input_nodes[i] = placeholder
|
self.new_input_nodes[i] = placeholder
|
||||||
|
|
||||||
def _create_placeholder_mapping(self) -> None:
|
def _create_placeholder_mapping(self) -> None:
|
||||||
@ -324,12 +325,18 @@ class DynamoGraphTransformer(torch.fx.Transformer):
|
|||||||
|
|
||||||
# Copy module metadata like the original implementation
|
# Copy module metadata like the original implementation
|
||||||
if hasattr(self.module, "meta"):
|
if hasattr(self.module, "meta"):
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
if "dynamo_flat_name_to_original_fqn" in self.module.meta:
|
if "dynamo_flat_name_to_original_fqn" in self.module.meta:
|
||||||
|
# pyrefly: ignore # index-error
|
||||||
result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
|
result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[
|
||||||
|
# pyrefly: ignore # index-error
|
||||||
"dynamo_flat_name_to_original_fqn"
|
"dynamo_flat_name_to_original_fqn"
|
||||||
]
|
]
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
if "dynamo_compile_id" in self.module.meta:
|
if "dynamo_compile_id" in self.module.meta:
|
||||||
|
# pyrefly: ignore # index-error
|
||||||
result_gm.meta["dynamo_compile_id"] = self.module.meta[
|
result_gm.meta["dynamo_compile_id"] = self.module.meta[
|
||||||
|
# pyrefly: ignore # index-error
|
||||||
"dynamo_compile_id"
|
"dynamo_compile_id"
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -361,8 +368,11 @@ def _suggest_or_raise_constraint_violation(
|
|||||||
torch._ops.OpOverloadPacket | torch._ops.OpOverload,
|
torch._ops.OpOverloadPacket | torch._ops.OpOverload,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
dim_constraints.solve()
|
dim_constraints.solve()
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
forced_specializations = dim_constraints.forced_specializations()
|
forced_specializations = dim_constraints.forced_specializations()
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
msg = dim_constraints.prettify_results(
|
msg = dim_constraints.prettify_results(
|
||||||
inspect.signature(orig_callable), # type: ignore[attr-defined]
|
inspect.signature(orig_callable), # type: ignore[attr-defined]
|
||||||
dynamic_shapes,
|
dynamic_shapes,
|
||||||
@ -383,9 +393,11 @@ def _suggest_or_raise_constraint_violation(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Error if we have any constraints on static values
|
# Error if we have any constraints on static values
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
for k in shape_env.var_to_range.keys():
|
for k in shape_env.var_to_range.keys():
|
||||||
if isinstance(k, sympy.Integer):
|
if isinstance(k, sympy.Integer):
|
||||||
constraint_violation_error = ConstraintViolationError(
|
constraint_violation_error = ConstraintViolationError(
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
||||||
"It appears that you're trying to set a constraint on a "
|
"It appears that you're trying to set a constraint on a "
|
||||||
f"value which we evaluated to have a static value of {k}. "
|
f"value which we evaluated to have a static value of {k}. "
|
||||||
|
@ -320,6 +320,7 @@ class GraphRegionTracker:
|
|||||||
if len(group) > 1:
|
if len(group) > 1:
|
||||||
region_group = []
|
region_group = []
|
||||||
min_rank = math.inf
|
min_rank = math.inf
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
for node in group:
|
for node in group:
|
||||||
# some nodes aren't in the topo ranking?
|
# some nodes aren't in the topo ranking?
|
||||||
if node in topological_ranking:
|
if node in topological_ranking:
|
||||||
|
@ -640,6 +640,7 @@ class GuardManagerWrapper:
|
|||||||
if isinstance(guard, RelationalGuard):
|
if isinstance(guard, RelationalGuard):
|
||||||
if guard not in self.printed_relational_guards:
|
if guard not in self.printed_relational_guards:
|
||||||
self.printed_relational_guards.add(guard)
|
self.printed_relational_guards.add(guard)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
body.writelines(self.get_guard_lines(guard))
|
body.writelines(self.get_guard_lines(guard))
|
||||||
else:
|
else:
|
||||||
body.writelines(
|
body.writelines(
|
||||||
@ -700,6 +701,7 @@ class GuardManagerWrapper:
|
|||||||
for guard in mgr.get_leaf_guards():
|
for guard in mgr.get_leaf_guards():
|
||||||
if isinstance(guard, RelationalGuard):
|
if isinstance(guard, RelationalGuard):
|
||||||
if guard not in relational_guards_seen:
|
if guard not in relational_guards_seen:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.code_parts.extend(get_code_parts(guard))
|
self.code_parts.extend(get_code_parts(guard))
|
||||||
relational_guards_seen.add(guard)
|
relational_guards_seen.add(guard)
|
||||||
else:
|
else:
|
||||||
@ -716,6 +718,7 @@ def from_numpy(a: Any) -> torch.Tensor:
|
|||||||
# Re-enable torch function since we disable it on leaf guards
|
# Re-enable torch function since we disable it on leaf guards
|
||||||
# we need it to properly construct the tensor if a default device is set
|
# we need it to properly construct the tensor if a default device is set
|
||||||
with torch.overrides._enable_torch_function():
|
with torch.overrides._enable_torch_function():
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a
|
return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a
|
||||||
|
|
||||||
|
|
||||||
@ -729,6 +732,7 @@ def uninteresting_files() -> set[str]:
|
|||||||
|
|
||||||
from torch._dynamo.polyfills.loader import POLYFILLED_MODULES
|
from torch._dynamo.polyfills.loader import POLYFILLED_MODULES
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
mods.extend(POLYFILLED_MODULES)
|
mods.extend(POLYFILLED_MODULES)
|
||||||
|
|
||||||
return {inspect.getfile(m) for m in mods}
|
return {inspect.getfile(m) for m in mods}
|
||||||
@ -2205,6 +2209,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Python math library doesn't support complex nan, so we need to use numpy
|
# Python math library doesn't support complex nan, so we need to use numpy
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if istype(val, complex) and np.isnan(val):
|
if istype(val, complex) and np.isnan(val):
|
||||||
code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
|
code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
|
||||||
self._set_guard_export_info(guard, code)
|
self._set_guard_export_info(guard, code)
|
||||||
@ -2495,6 +2500,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||||||
# sources for the corresponding tensor dimension.
|
# sources for the corresponding tensor dimension.
|
||||||
return [
|
return [
|
||||||
TensorPropertySource(source, TensorProperty.SIZE, dim)
|
TensorPropertySource(source, TensorProperty.SIZE, dim)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
for source in output_graph.tracked_fakes_id_to_source[t_id]
|
for source in output_graph.tracked_fakes_id_to_source[t_id]
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -2531,6 +2537,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||||||
equalities_inputs = None
|
equalities_inputs = None
|
||||||
|
|
||||||
def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]:
|
def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]:
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return output_graph.shape_env.produce_guards_verbose(
|
return output_graph.shape_env.produce_guards_verbose(
|
||||||
[a.fake for a in fs], # type: ignore[misc]
|
[a.fake for a in fs], # type: ignore[misc]
|
||||||
[a.source for a in fs],
|
[a.source for a in fs],
|
||||||
@ -2538,6 +2545,7 @@ class GuardBuilder(GuardBuilderBase):
|
|||||||
equalities_inputs=equalities_inputs,
|
equalities_inputs=equalities_inputs,
|
||||||
source_ref=self.source_ref,
|
source_ref=self.source_ref,
|
||||||
# Export keeps static.
|
# Export keeps static.
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
ignore_static=(not output_graph.export),
|
ignore_static=(not output_graph.export),
|
||||||
langs=langs,
|
langs=langs,
|
||||||
)
|
)
|
||||||
@ -2599,7 +2607,9 @@ class GuardBuilder(GuardBuilderBase):
|
|||||||
if not python_fallback:
|
if not python_fallback:
|
||||||
assert cpp_code_parts # type: ignore[possibly-undefined]
|
assert cpp_code_parts # type: ignore[possibly-undefined]
|
||||||
code_parts, source_to_symbol = (
|
code_parts, source_to_symbol = (
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
cpp_code_parts.exprs,
|
cpp_code_parts.exprs,
|
||||||
|
# pyrefly: ignore # unbound-name, missing-attribute
|
||||||
cpp_code_parts.source_to_symbol,
|
cpp_code_parts.source_to_symbol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2630,7 +2640,9 @@ class GuardBuilder(GuardBuilderBase):
|
|||||||
|
|
||||||
assert cpp_code_parts # type: ignore[possibly-undefined]
|
assert cpp_code_parts # type: ignore[possibly-undefined]
|
||||||
code_parts, source_to_symbol = (
|
code_parts, source_to_symbol = (
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
cpp_code_parts.exprs,
|
cpp_code_parts.exprs,
|
||||||
|
# pyrefly: ignore # unbound-name, missing-attribute
|
||||||
cpp_code_parts.source_to_symbol,
|
cpp_code_parts.source_to_symbol,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -3240,6 +3252,7 @@ class GuardsStatePickler(pickle.Pickler):
|
|||||||
assert _.__closure__ is not None
|
assert _.__closure__ is not None
|
||||||
return _.__closure__[0]
|
return _.__closure__[0]
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def reducer_override(
|
def reducer_override(
|
||||||
self, obj: Any
|
self, obj: Any
|
||||||
) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
|
) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
|
||||||
@ -4072,10 +4085,13 @@ class CheckFunctionManager:
|
|||||||
and (cache_entry := self.guard_manager.cache_entry) is not None
|
and (cache_entry := self.guard_manager.cache_entry) is not None
|
||||||
and (extra_state := self.guard_manager.extra_state) is not None
|
and (extra_state := self.guard_manager.extra_state) is not None
|
||||||
):
|
):
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
assert isinstance(cache_entry, CacheEntry)
|
assert isinstance(cache_entry, CacheEntry)
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
assert isinstance(extra_state, ExtraState)
|
assert isinstance(extra_state, ExtraState)
|
||||||
reason = f"Cache line invalidated because {obj_str} got deallocated"
|
reason = f"Cache line invalidated because {obj_str} got deallocated"
|
||||||
deleted_guard_manager = DeletedGuardManagerWrapper(reason)
|
deleted_guard_manager = DeletedGuardManagerWrapper(reason)
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
extra_state.invalidate(cache_entry, deleted_guard_manager)
|
extra_state.invalidate(cache_entry, deleted_guard_manager)
|
||||||
self.guard_manager = deleted_guard_manager
|
self.guard_manager = deleted_guard_manager
|
||||||
|
|
||||||
|
@ -707,6 +707,7 @@ class OutputGraph(OutputGraphCommon):
|
|||||||
self.backward_state_proxy: Optional[torch.fx.Proxy] = None
|
self.backward_state_proxy: Optional[torch.fx.Proxy] = None
|
||||||
self.backward_state_var: Optional[str] = None
|
self.backward_state_var: Optional[str] = None
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
self.name_of_builtins_dict_key_in_fglobals: str = (
|
self.name_of_builtins_dict_key_in_fglobals: str = (
|
||||||
self.install_builtins_dict_in_fglobals()
|
self.install_builtins_dict_in_fglobals()
|
||||||
)
|
)
|
||||||
@ -1146,6 +1147,7 @@ class OutputGraph(OutputGraphCommon):
|
|||||||
vt = self.root_tx.output.side_effects.track_object_existing(target, vt)
|
vt = self.root_tx.output.side_effects.track_object_existing(target, vt)
|
||||||
|
|
||||||
assert "tensor_dict" not in vt.as_proxy().node.meta
|
assert "tensor_dict" not in vt.as_proxy().node.meta
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target)
|
vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target)
|
||||||
|
|
||||||
return vt
|
return vt
|
||||||
@ -1157,6 +1159,7 @@ class OutputGraph(OutputGraphCommon):
|
|||||||
install_guard(source.make_guard(GuardBuilder.NN_MODULE))
|
install_guard(source.make_guard(GuardBuilder.NN_MODULE))
|
||||||
|
|
||||||
def wrap_name(module_key: str) -> VariableTracker:
|
def wrap_name(module_key: str) -> VariableTracker:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
return NNModuleVariable(type(target), module_key, target, **options)
|
return NNModuleVariable(type(target), module_key, target, **options)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -1970,7 +1973,9 @@ class OutputGraph(OutputGraphCommon):
|
|||||||
tx = self.root_tx
|
tx = self.root_tx
|
||||||
assert tx is not None
|
assert tx is not None
|
||||||
if (ds := tx.distributed_state) is not None and ds.all_states is None:
|
if (ds := tx.distributed_state) is not None and ds.all_states is None:
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
compile_pg = ds.compile_pg
|
compile_pg = ds.compile_pg
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
log.info("compiler_collective %s", ds.local_state)
|
log.info("compiler_collective %s", ds.local_state)
|
||||||
torch._logging.trace_structured(
|
torch._logging.trace_structured(
|
||||||
"artifact",
|
"artifact",
|
||||||
@ -1978,6 +1983,7 @@ class OutputGraph(OutputGraphCommon):
|
|||||||
"name": "compiler_collective",
|
"name": "compiler_collective",
|
||||||
"encoding": "string",
|
"encoding": "string",
|
||||||
},
|
},
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
payload_fn=lambda: ds.local_state.render(),
|
payload_fn=lambda: ds.local_state.render(),
|
||||||
)
|
)
|
||||||
device_types = compile_pg._device_types
|
device_types = compile_pg._device_types
|
||||||
@ -1991,7 +1997,9 @@ class OutputGraph(OutputGraphCommon):
|
|||||||
dynamo_timed("compiler_collective", log_pt2_compile_event=True),
|
dynamo_timed("compiler_collective", log_pt2_compile_event=True),
|
||||||
):
|
):
|
||||||
all_states: list[Any] = [None] * compile_pg.size()
|
all_states: list[Any] = [None] * compile_pg.size()
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
|
dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
ds.all_states = all_states
|
ds.all_states = all_states
|
||||||
# Clear speculation log, because are tracing may diverge due to
|
# Clear speculation log, because are tracing may diverge due to
|
||||||
# this information from the compiler collective
|
# this information from the compiler collective
|
||||||
@ -2321,6 +2329,7 @@ class OutputGraph(OutputGraphCommon):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
return compiled_fn
|
return compiled_fn
|
||||||
|
|
||||||
def dedup_pass(self) -> dict[str, torch.fx.GraphModule]:
|
def dedup_pass(self) -> dict[str, torch.fx.GraphModule]:
|
||||||
@ -2375,6 +2384,7 @@ class OutputGraph(OutputGraphCommon):
|
|||||||
isinstance(b, torch.SymBool)
|
isinstance(b, torch.SymBool)
|
||||||
and (r := b.node.maybe_as_bool()) is not None
|
and (r := b.node.maybe_as_bool()) is not None
|
||||||
):
|
):
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
return r
|
return r
|
||||||
# TODO: We can also technically remove all cases when the input
|
# TODO: We can also technically remove all cases when the input
|
||||||
# doesn't have unbacked inputs, since it's all in the ShapeEnv
|
# doesn't have unbacked inputs, since it's all in the ShapeEnv
|
||||||
@ -2740,6 +2750,7 @@ def check_pt2_compliant_op(
|
|||||||
hints=[],
|
hints=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
op = getattr(target, overload)
|
op = getattr(target, overload)
|
||||||
if torch.Tag.pt2_compliant_tag in op.tags:
|
if torch.Tag.pt2_compliant_tag in op.tags:
|
||||||
encountered_compliant_op(op)
|
encountered_compliant_op(op)
|
||||||
@ -2747,6 +2758,7 @@ def check_pt2_compliant_op(
|
|||||||
encountered_non_compliant_op(
|
encountered_non_compliant_op(
|
||||||
op,
|
op,
|
||||||
f"Encountered the torch.ops.OpOverloadPacket {target} "
|
f"Encountered the torch.ops.OpOverloadPacket {target} "
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
f"which resolves to the overload ({overload}) that is "
|
f"which resolves to the overload ({overload}) that is "
|
||||||
f"not PT2 compliant.",
|
f"not PT2 compliant.",
|
||||||
)
|
)
|
||||||
@ -2767,6 +2779,7 @@ class LazyProxy:
|
|||||||
**kwargs: P.kwargs,
|
**kwargs: P.kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.tracer = tracer
|
self.tracer = tracer
|
||||||
|
# pyrefly: ignore # invalid-type-var
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
@ -319,6 +319,7 @@ def _get_code_source(code: types.CodeType) -> tuple[str, str]:
|
|||||||
code_source = _find_code_source(toplevel)
|
code_source = _find_code_source(toplevel)
|
||||||
if code_source is None:
|
if code_source is None:
|
||||||
_raise_resolution_error(code, toplevel)
|
_raise_resolution_error(code, toplevel)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return toplevel.__qualname__, code_source.strip(".")
|
return toplevel.__qualname__, code_source.strip(".")
|
||||||
|
|
||||||
|
|
||||||
@ -593,9 +594,11 @@ class CompilePackage:
|
|||||||
f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
|
f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
self._source_info = dynamo.source_info
|
self._source_info = dynamo.source_info
|
||||||
|
|
||||||
main, *codes = dynamo.codes
|
main, *codes = dynamo.codes
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
self._codes = {self._innermost_fn.__code__: main}
|
self._codes = {self._innermost_fn.__code__: main}
|
||||||
for code in codes:
|
for code in codes:
|
||||||
self._codes[SerializedCode.to_code_object(code.python_code)] = code
|
self._codes[SerializedCode.to_code_object(code.python_code)] = code
|
||||||
@ -603,6 +606,7 @@ class CompilePackage:
|
|||||||
self._add_function(
|
self._add_function(
|
||||||
self._innermost_fn.__code__, self._innermost_fn.__module__
|
self._innermost_fn.__code__, self._innermost_fn.__module__
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def _add_function(
|
def _add_function(
|
||||||
@ -746,6 +750,7 @@ class CompilePackage:
|
|||||||
for name in names:
|
for name in names:
|
||||||
module.__dict__.pop(name)
|
module.__dict__.pop(name)
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
self._installed_globals = {}
|
self._installed_globals = {}
|
||||||
|
|
||||||
_reset_precompile_entries(self._innermost_fn.__code__)
|
_reset_precompile_entries(self._innermost_fn.__code__)
|
||||||
|
@ -167,6 +167,7 @@ class CodeId:
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class CodeState:
|
class CodeState:
|
||||||
automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
|
automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
default_factory=lambda: defaultdict(FrameStateSizeEntry)
|
default_factory=lambda: defaultdict(FrameStateSizeEntry)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -851,6 +852,7 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
|
|||||||
not _CODE_STATE
|
not _CODE_STATE
|
||||||
and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
|
and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
|
||||||
):
|
):
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
extra_read_key = get_extra_cache_key(sticky_read)
|
extra_read_key = get_extra_cache_key(sticky_read)
|
||||||
if extra_read_key is not None:
|
if extra_read_key is not None:
|
||||||
get_extra_remote_code_state(extra_read_key)
|
get_extra_remote_code_state(extra_read_key)
|
||||||
|
@@ -196,6 +196,7 @@ def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:


 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T1],
     /,
@@ -205,6 +206,7 @@ def zip_longest(


 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T1],
     iter2: Iterable[_T2],
@@ -213,6 +215,7 @@ def zip_longest(


 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T1],
     iter2: Iterable[_T2],
@@ -223,6 +226,7 @@ def zip_longest(


 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T],
     iter2: Iterable[_T],
@@ -233,6 +237,7 @@ def zip_longest(


 @overload
+# pyrefly: ignore # inconsistent-overload
 def zip_longest(
     iter1: Iterable[_T],
     iter2: Iterable[_T],
@@ -30,10 +30,12 @@ _Us = TypeVarTuple("_Us")


 @overload
+# pyrefly: ignore # inconsistent-overload
 def attrgetter(attr: str, /) -> Callable[[Any], _U]: ...


 @overload
+# pyrefly: ignore # inconsistent-overload
 def attrgetter(
     attr1: str, attr2: str, /, *attrs: str
 ) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...
@@ -68,10 +70,12 @@ def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]:


 @overload
+# pyrefly: ignore # inconsistent-overload
 def itemgetter(item: _T, /) -> Callable[[Any], _U]: ...


 @overload
+# pyrefly: ignore # inconsistent-overload
 def itemgetter(
     item1: _T1, item2: _T2, /, *items: Unpack[_Ts]
 ) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...
@@ -17,6 +17,7 @@ __all__ = ["fspath"]
 @substitute_in_graph(os.fspath, can_constant_fold_through=True)
 def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
     if isinstance(path, (str, bytes)):
+        # pyrefly: ignore # bad-return
         return path

     path_type = type(path)
@@ -171,6 +171,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
             or optree.is_namedtuple_class(treespec.type)
             or optree.is_structseq_class(treespec.type)
         ):
+            # pyrefly: ignore # bad-return
             return treespec._unflatten_func(
                 treespec._metadata,
                 children_representations,
@@ -49,8 +49,11 @@ class ProfileMetrics:
         if isinstance(other, int):
             other = ProfileMetrics(other, other, other)
         return ProfileMetrics(
+            # pyrefly: ignore # no-matching-overload
             self.microseconds / max(1, other.microseconds),
+            # pyrefly: ignore # bad-argument-type
             self.operators / max(1, other.operators),
+            # pyrefly: ignore # bad-argument-type
             self.fusions / max(1, other.fusions),
         )

@@ -370,13 +370,16 @@ isolate_fails_code_str = None

         try:
             if isinstance(kernel, Autotuner):
+                # pyrefly: ignore # missing-attribute
                 if isinstance(kernel.fn, Heuristics):
                     model_str += "ERROR: Repro will not work as intended, "
                     model_str += "triton.runtime.autotuner.Heuristics is not currently supported\n"
                     break

                 config_strs = []
+                # pyrefly: ignore # missing-attribute
                 for kernel_config in kernel.configs:
+                    # pyrefly: ignore # bad-argument-type
                     config_strs.append(f"""triton.Config(
     {str(kernel_config.kwargs)},
     num_warps={kernel_config.num_warps},
@@ -394,8 +397,10 @@ isolate_fails_code_str = None
 """).strip()

             model_str += "\n@triton.jit\n"
+            # pyrefly: ignore # missing-attribute
             src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src
             fn_name = (
+                # pyrefly: ignore # missing-attribute
                 kernel._fn_name
                 if isinstance(kernel, JITFunction)
                 else kernel.fn._fn_name
@@ -409,7 +414,9 @@ isolate_fails_code_str = None
             model_str += "ERROR: Repro will not work as intended, "
             model_str += f"User defined triton kernel exception: {e}\n"

+        # pyrefly: ignore # unbound-name
         if len(kernel_side_table.constant_args) > 0:
+            # pyrefly: ignore # unbound-name
             model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n"

     model_str += NNModuleToString.convert(gm)
@@ -420,8 +427,10 @@ isolate_fails_code_str = None
     # Extract from graph placeholders and their corresponding arguments
     placeholder_targets = fx_placeholder_targets(gm)
     for placeholder, arg in zip(placeholder_targets, args):
+        # pyrefly: ignore # unbound-name
         if isinstance(arg, (int, torch.SymInt)):
             writer.symint(placeholder, arg)
+        # pyrefly: ignore # unbound-name
         elif isinstance(arg, torch.Tensor):
             # TODO: improve these names with FQN
             writer.tensor(placeholder, arg)
@@ -431,16 +440,20 @@ isolate_fails_code_str = None
             writer.unsupported(placeholder, arg)

         # Extract symbolic variables from the same arguments
+        # pyrefly: ignore # unbound-name
         if isinstance(arg, torch.SymInt):
             sym_name = str(arg.node)
             if arg.node.hint is not None:
                 used_syms[sym_name] = arg.node.hint
+        # pyrefly: ignore # unbound-name
         elif isinstance(arg, torch.Tensor):
             # Extract symbolic variables from tensor shapes and strides
             for dim in arg.shape:
+                # pyrefly: ignore # unbound-name
                 if isinstance(dim, torch.SymInt) and dim.node.hint is not None:
                     used_syms[str(dim.node)] = dim.node.hint
             for stride in arg.stride():
+                # pyrefly: ignore # unbound-name
                 if isinstance(stride, torch.SymInt) and stride.node.hint is not None:
                     used_syms[str(stride.node)] = stride.node.hint

@@ -758,6 +771,7 @@ def repro_common(
         # TODO: speed this up
         mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args)

+    # pyrefly: ignore # bad-assignment
     torch._inductor.config.generate_intermediate_hooks = True

     return mod, args
@@ -301,6 +301,7 @@ def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]:
 def repro_common(
     options: Any, exported_program: ExportedProgram
 ) -> tuple[torch.fx.GraphModule, Any, Any]:
+    # pyrefly: ignore # bad-assignment
     torch._inductor.config.generate_intermediate_hooks = True
     mod = exported_program.module(check_guards=False)
     args, kwargs = exported_program.example_inputs
@@ -422,6 +423,7 @@ def repro_minify(
     ) -> bool:
         # Need to export first so the in_spec and out_spec are populated
         tuple_inputs = tuple(flat_example_inputs)
+        # pyrefly: ignore # bad-assignment
         gm = export_for_aoti_minifier(
             gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error
         )
@@ -102,6 +102,7 @@ def _bytecode_from_template_with_split(
 def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
     # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
     # on torch._dynamo.utils.
+    # pyrefly: ignore # unknown-name
     global __import_torch_dot__dynamo_dot_utils
     try:
         dummy
@@ -555,6 +556,7 @@ class ContinueExecutionCache:

         # remap original instructions' exception table entries
         if old_hook_target_remap:
+            # pyrefly: ignore # unbound-name
             assert is_py311_plus
             for inst in instructions:
                 if (
@@ -696,6 +696,7 @@ class SideEffects:
                 cg.add_cache(var)
                 var.source = LocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
             elif var.source is None:
+                # pyrefly: ignore # bad-assignment
                 var.source = LocalCellSource(var.local_name)
         elif isinstance(var, variables.TensorVariable):
             # NOTE: for historical reasons we never assigned local sources
@@ -732,6 +733,7 @@ class SideEffects:
         if isinstance(var, variables.UserDefinedObjectVariable):

             def load_new_method() -> None:
+                # pyrefly: ignore # missing-attribute
                 assert var.base_cls_vt is not None
                 cg(var.base_cls_vt)  # type: ignore[attr-defined]
                 cg.extend_output([cg.create_load_attr("__new__")])
@@ -978,7 +980,9 @@ class SideEffects:

         elif self.is_attribute_mutation(var):
             if isinstance(
-                var, variables.UserDefinedDictVariable
+                var,
+                variables.UserDefinedDictVariable,
+                # pyrefly: ignore # bad-argument-type
             ) and self.is_modified(var._dict_vt):
                 # Do dict related update manually here. The store_attr
                 # mutations will be applied later.
@@ -1011,6 +1015,7 @@ class SideEffects:
                     ]
                 )

+                # pyrefly: ignore # bad-argument-type
                 cg(var._dict_vt, allow_cache=False)  # Don't codegen via source
                 cg.extend_output(
                     [
@@ -1031,7 +1036,9 @@ class SideEffects:
                     ]
                 )
             elif isinstance(
-                var, variables.UserDefinedListVariable
+                var,
+                variables.UserDefinedListVariable,
+                # pyrefly: ignore # bad-argument-type
             ) and self.is_modified(var._list_vt):
                 # Update the list to the updated items. Be careful in
                 # calling the list methods and not the overridden methods.
@@ -1048,6 +1055,7 @@ class SideEffects:
                     ]
                 )

+                # pyrefly: ignore # bad-argument-type
                 cg(var._list_vt, allow_cache=False)  # Don't codegen via source
                 cg.extend_output(
                     [
@@ -563,6 +563,7 @@ def log_graph_break(
         )
     else:
         user_stack = get_stack_above_dynamo() + user_stack  # type: ignore[assignment]
+        # pyrefly: ignore # bad-argument-type
         user_stack = collapse_resume_frames(user_stack)
     user_stack_formatted = "".join(traceback.format_list(user_stack))
     user_stack_trace = (
@@ -1040,6 +1041,7 @@ class BytecodeDispatchTableMeta(type):
             op: getattr(cls, opname, functools.partial(_missing, opname))
             for opname, op in dis.opmap.items()
         }
+        # pyrefly: ignore # missing-attribute
         cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]


@@ -1788,13 +1790,17 @@ class InstructionTranslatorBase(
             source = self.import_source(module_name)

         if self.exec_recorder:
+            # pyrefly: ignore # unbound-name
             self.exec_recorder.add_local_mod(recorded_name, value)

+        # pyrefly: ignore # unbound-name
         if istype(value, (types.ModuleType, DummyModule)):
+            # pyrefly: ignore # unbound-name
             self.push(PythonModuleVariable(value, source=source))
         else:
             unimplemented_v2(
                 gb_type="Bad import result",
+                # pyrefly: ignore # unbound-name
                 context=typestr(value),
                 explanation="Import result is not a Python module.",
                 hints=[],
@@ -1873,6 +1879,7 @@ class InstructionTranslatorBase(
         exit, exc = self.popn(2)
         assert exc is None
         self.push(exc)
+        # pyrefly: ignore # bad-argument-type
         self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {}))

     def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None:
@@ -2294,7 +2301,9 @@ class InstructionTranslatorBase(
         ):
             return True
         elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
-            exc_instance.fn, expected_type.fn
+            exc_instance.fn,
+            # pyrefly: ignore # missing-attribute
+            expected_type.fn,
         ):
             return True

@@ -2354,26 +2363,37 @@ class InstructionTranslatorBase(
         assert isinstance(null, NullVariable)

         if not isinstance(
-            argsvars, BaseListVariable
+            # pyrefly: ignore # unbound-name
+            argsvars,
+            BaseListVariable,
+            # pyrefly: ignore # unbound-name
         ) and argsvars.has_force_unpack_var_sequence(self):
+            # pyrefly: ignore # unbound-name
             argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))

         # Unpack for cases like fn(**obj) where obj is a map
+        # pyrefly: ignore # unbound-name
         if isinstance(kwargsvars, UserDefinedObjectVariable):
             kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars)  # type: ignore[arg-type]

+        # pyrefly: ignore # unbound-name
         if not isinstance(argsvars, BaseListVariable) or not isinstance(
-            kwargsvars, ConstDictVariable
+            # pyrefly: ignore # unbound-name
+            kwargsvars,
+            ConstDictVariable,
         ):
             unimplemented_v2(
                 gb_type="Variadic function call with bad args/kwargs type",
+                # pyrefly: ignore # unbound-name
                 context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
                 explanation="Expected args to be a list and kwargs to be a dict",
                 hints=[*graph_break_hints.USER_ERROR],
             )

         # Map to a dictionary of str -> VariableTracker
+        # pyrefly: ignore # unbound-name, missing-attribute
         kwargsvars = kwargsvars.keys_as_python_constant()
+        # pyrefly: ignore # unbound-name, missing-attribute
         self.call_function(fn, argsvars.items, kwargsvars)

     @break_graph_if_unsupported(push=1)
@@ -2437,6 +2457,7 @@ class InstructionTranslatorBase(

     def LOAD_ATTR(self, inst: Instruction) -> None:
         if sys.version_info >= (3, 12):
+            # pyrefly: ignore # unsupported-operation
             if inst.arg % 2:
                 self.LOAD_METHOD(inst)
                 return
@@ -3029,14 +3050,17 @@ class InstructionTranslatorBase(
                 "(i.e. `a, b, c = d`).",
                 hints=[*graph_break_hints.USER_ERROR],
             )
+        # pyrefly: ignore # unbound-name
         if len(val) != inst.argval:
             unimplemented_v2(
                 gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
+                # pyrefly: ignore # unbound-name
                 context=f"expected length: {inst.argval}, actual: {len(val)}",
                 explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
                 "(i.e. `a, b, c = d`) with unexpected length.",
                 hints=[*graph_break_hints.DYNAMO_BUG],
             )
+        # pyrefly: ignore # unbound-name
         for i in reversed(val):
             self.push(i)

@@ -3409,9 +3433,13 @@ class InstructionTranslatorBase(
             args = [contents[1]]

         if kw_names:
+            # pyrefly: ignore # bad-argument-type
             args = args + contents[2 : -len(kw_names)]
+            # pyrefly: ignore # bad-argument-type
             kwargs_list = contents[-len(kw_names) :]
+            # pyrefly: ignore # no-matching-overload
             kwargs = dict(zip(kw_names, kwargs_list))
+            # pyrefly: ignore # bad-argument-type
             assert len(kwargs) == len(kw_names)
         else:
             args = args + contents[2:]
@@ -4118,6 +4146,7 @@ class InstructionTranslator(InstructionTranslatorBase):
             and isinstance(tos, LocalGeneratorObjectVariable)
         ):
             self.stack[-1] = ListIteratorVariable(
+                # pyrefly: ignore # unbound-name
                 tos.force_unpack_var_sequence(self),
                 mutation_type=ValueMutationNew(),
             )
@@ -4188,6 +4217,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
    """Trace and inline a called method"""

    symbolic_result: Optional[VariableTracker]
+    # pyrefly: ignore # bad-override
    parent: InstructionTranslatorBase

    @classmethod
@@ -4231,6 +4261,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         # trace through.
         if (
             hasattr(getattr(func, "fn", None), "_origin")
+            # pyrefly: ignore # missing-attribute
             and func.fn._origin is produce_trampoline_autograd_apply
         ):
             # Known sound
@@ -4305,12 +4336,14 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
             tracing_ctx.previously_inlined_functions[code] = result

         try:
+            # pyrefly: ignore # missing-attribute
             sub_locals = func.bind_args(parent, args, kwargs)
         except TypeError as e:
             # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
             raise ArgsMismatchError(  # noqa: B904
                 "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
                     reason=str(e),
+                    # pyrefly: ignore # missing-attribute
                     func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
                     args=[arg.python_type() for arg in args],
                     kwargs=kwargs,
@@ -4394,6 +4427,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
             sub_locals,
             parent.symbolic_globals,
             parent.symbolic_torch_function_state,
+            # pyrefly: ignore # bad-argument-type
             func,
         )
         return tracer
@@ -153,7 +153,9 @@ class CPythonTestCase(TestCase):
     assertTupleEqual = unittest.TestCase.assertTupleEqual
     assertSetEqual = unittest.TestCase.assertSetEqual
     assertDictEqual = polyfills.assert_dict_equal
+    # pyrefly: ignore # bad-override
     assertRaises = unittest.TestCase.assertRaises
+    # pyrefly: ignore # bad-override
     assertRaisesRegex = unittest.TestCase.assertRaisesRegex
     assertWarns = unittest.TestCase.assertWarns
     assertWarnsRegex = unittest.TestCase.assertWarnsRegex
@@ -169,8 +171,10 @@ class CPythonTestCase(TestCase):
     ) -> Callable[..., Any]:
         # We want to compile only the test function, excluding any setup code
         # from unittest
+
         method = getattr(self, self._testMethodName)
         method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
+
         setattr(self, self._testMethodName, method)
         return fn

@@ -207,6 +207,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
         launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py"))
         with open(launch_file) as f:
             launch_code = f.read()
+
         self.assertTrue(os.path.exists(launch_file))

         args = ["python3", launch_file, "minify", *minifier_args]
@@ -218,6 +219,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
         print("minifier stdout:", launch_proc.stdout.decode("utf-8"))
         stderr = launch_proc.stderr.decode("utf-8")
         print("minifier stderr:", stderr)
+
         self.assertNotIn("Input graph did not fail the tester", stderr)

         return launch_proc, launch_code
@@ -230,6 +232,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
         repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py"))
         with open(repro_file) as f:
             repro_code = f.read()
+
         self.assertTrue(os.path.exists(repro_file))

         repro_proc = self._maybe_subprocess_run(
@@ -296,11 +299,14 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
         if expected_error is None:
             # Just check that there was no error
             self.assertEqual(test_proc.returncode, 0)
+
             self.assertIsNone(repro_dir)
             return None
         # NB: Intentionally do not test return code; we only care about
         # actually generating the repro, we don't have to crash
+
         self.assertIn(expected_error, test_proc.stderr.decode("utf-8"))
+
         self.assertIsNotNone(repro_dir)
         print("running minifier", file=sys.stderr)
         _minifier_proc, minifier_code = self._run_minifier_launcher(
@@ -311,6 +317,7 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
         )
         print("running repro", file=sys.stderr)
         repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate)
+
         self.assertIn(expected_error, repro_proc.stderr.decode("utf-8"))
         self.assertNotEqual(repro_proc.returncode, 0)
         return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code)
@@ -496,6 +496,7 @@ def make_test_cls_with_patches(
 def skipIfNotPy311(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     if sys.version_info >= (3, 11):
         return fn
+    # pyrefly: ignore # bad-return, bad-argument-type
     return unittest.skip(fn)


@@ -3005,6 +3005,7 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
                 obj = torch_dir + k[len("torch/") :]
             if obj is not None:
                 if is_annotate_wrapped_function(obj):
+                    # pyrefly: ignore # missing-attribute
                     obj = obj.__wrapped__
                 if is_lru_cache_wrapped_function(obj):
                     obj = obj.__wrapped__
@@ -295,11 +295,13 @@ def increment_op_count(cnt: int) -> None:
 def calculate_time_spent() -> dict[str, float]:
     total_by_key = {}
     for phase, timing in cumulative_time_spent_ns.items():
+        # pyrefly: ignore # unsupported-operation
         total_by_key[phase] = timing / 1e9

     total_by_key["total_wall_time"] = total_by_key.get(
         "entire_frame_compile", 0
     ) + total_by_key.get("entire_backward_compile", 0)
+    # pyrefly: ignore # bad-return
     return total_by_key


@@ -798,6 +800,7 @@ def compile_times(repr: Literal["str"], aggregate: bool = False) -> str: ...


 @overload
+# pyrefly: ignore # inconsistent-overload
 def compile_times(
     repr: Literal["csv"], aggregate: bool = False
 ) -> tuple[list[str], list[object]]: ...
@@ -1463,6 +1466,7 @@ class CompilationMetrics:
         compile_id = all_metrics.get("compile_id")
         all_metrics["compile_id"] = str(compile_id) if compile_id else None

+        # pyrefly: ignore # bad-argument-type
         return cls(**all_metrics)


@@ -2253,6 +2257,7 @@ def is_jit_model(
     Union[
         torch.jit._trace.TopLevelTracedModule,
         torch.jit._script.RecursiveScriptModule,
+        # pyrefly: ignore # invalid-param-spec
         torch.jit.ScriptFunction[Any, Any],
         torch.jit.ScriptModule,
     ]
@@ -2361,6 +2366,7 @@ def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]:
         cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
     saved_state = [
         (param, param._version, torch.clone(param))
+        # pyrefly: ignore # bad-argument-type
         for param in itertools.chain(gm.parameters(), gm.buffers())
     ]

@@ -2626,13 +2632,16 @@ def get_items_from_dict(obj: dict[K, V]) -> Iterable[tuple[K, Union[V, Any]]]:
     if istype(obj, (dict, OrderedDict)):
         return obj.items()
     elif isinstance(obj, OrderedDict):
+        # pyrefly: ignore # bad-argument-type
         return [(k, OrderedDict.__getitem__(obj, k)) for k in OrderedDict.keys(obj)]
     else:
+        # pyrefly: ignore # bad-argument-type
         return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)]


 def nn_module_new(cls: Any) -> Any:
     obj = object_new(cls)
+    # pyrefly: ignore # bad-argument-type
     torch.nn.Module.__init__(obj)
     return obj

@@ -2679,6 +2688,7 @@ def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any:
     dict_class = dict
     if isinstance(d, OrderedDict):
         dict_class = OrderedDict
+    # pyrefly: ignore # bad-argument-type
     return next(itertools.islice(dict_class.keys(d), n, n + 1))


@@ -3222,8 +3232,10 @@ def format_func_info(code: CodeType) -> str:
 @contextlib.contextmanager
 def disable_cache_limit() -> Generator[None, None, None]:
     prior = config.recompile_limit
+    # pyrefly: ignore # bad-assignment
     config.recompile_limit = sys.maxsize
     prior_acc_limit = config.accumulated_recompile_limit
+    # pyrefly: ignore # bad-assignment
     config.accumulated_recompile_limit = sys.maxsize

     try:
@@ -3958,6 +3970,7 @@ class numpy_operator_wrapper(Generic[_P, R]):
     def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any:
         assert not kwargs

+        # pyrefly: ignore # bad-assignment
         args = (
             tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
         )
@@ -4157,6 +4170,7 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
             # (x) + (y)
             # ~~^~~~~~~
             while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
+                # pyrefly: ignore # unbound-name
                 if ch in "\\#":
                     cur_lineno, cur_col = nextline(cur_lineno, cur_col)
                 else:
@@ -4507,6 +4521,7 @@ class GmWrapper(torch.nn.Module):
         self.unflatten_fn = unflatten_fn

     def forward(self, *args: Any) -> Any:
+        # pyrefly: ignore # annotation-mismatch
         args: list[Any] = list(args)
         return self.gm(*self.unflatten_fn(args))

@@ -1028,6 +1028,7 @@ class BuiltinVariable(VariableTracker):

         def call_self_handler(tx: "InstructionTranslator", args, kwargs):
             try:
+                # pyrefly: ignore # not-callable
                 result = self_handler(tx, *args, **kwargs)
                 if result is not None:
                     return result
@@ -1035,6 +1036,7 @@ class BuiltinVariable(VariableTracker):
                 # Check if binding is bad. inspect signature bind is expensive.
                 # So check only when handler call fails.
                 try:
+                    # pyrefly: ignore # bad-argument-type
                     inspect.signature(self_handler).bind(tx, *args, **kwargs)
                 except TypeError as e:
                     has_constant_handler = obj.has_constant_handler(args, kwargs)
@@ -1087,6 +1089,7 @@ class BuiltinVariable(VariableTracker):
                         hints=[*graph_break_hints.DYNAMO_BUG],
                         from_exc=exc,
                     )
+                # pyrefly: ignore # unbound-name
                 return VariableTracker.build(tx, res)

             else:
@@ -1115,6 +1118,7 @@ class BuiltinVariable(VariableTracker):
                         tx,
                         args=list(map(ConstantVariable.create, exc.args)),
                     )
+                # pyrefly: ignore # unbound-name
                 return VariableTracker.build(tx, res)

             handlers.append(constant_fold_handler)
@@ -1437,6 +1441,7 @@ class BuiltinVariable(VariableTracker):
             resolved_fn = getattr(self.fn, name)
             if resolved_fn in dict_methods:
                 if isinstance(args[0], variables.UserDefinedDictVariable):
+                    # pyrefly: ignore # missing-attribute
                     return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
                 elif isinstance(args[0], variables.ConstDictVariable):
                     return args[0].call_method(tx, name, args[1:], kwargs)
@@ -1445,6 +1450,7 @@ class BuiltinVariable(VariableTracker):
             resolved_fn = getattr(self.fn, name)
             if resolved_fn in set_methods:
                 if isinstance(args[0], variables.UserDefinedSetVariable):
+                    # pyrefly: ignore # missing-attribute
                     return args[0]._set_vt.call_method(tx, name, args[1:], kwargs)
                 elif isinstance(args[0], variables.SetVariable):
                     return args[0].call_method(tx, name, args[1:], kwargs)
@@ -1533,10 +1539,12 @@ class BuiltinVariable(VariableTracker):
             if type(arg.value).__str__ is object.__str__:
                 # Rely on the object str method
                 try:
+                    # pyrefly: ignore # unbound-name
                     return variables.ConstantVariable.create(value=str_method())
                 except AttributeError:
                     # Graph break
                     return
+            # pyrefly: ignore # unbound-name
             elif is_wrapper_or_member_descriptor(str_method):
                 unimplemented_v2(
                     gb_type="Attempted to a str() method implemented in C/C++",
@@ -1653,8 +1661,10 @@ class BuiltinVariable(VariableTracker):
             else:
                 raw_b = b.raw_value
             if self.fn is max:
+                # pyrefly: ignore # missing-attribute
                 raw_res = max(a.raw_value, raw_b)
             else:
+                # pyrefly: ignore # missing-attribute
                 raw_res = min(a.raw_value, raw_b)

             need_unwrap = any(
@@ -2106,6 +2116,7 @@ class BuiltinVariable(VariableTracker):
             )

         if isinstance(arg, variables.UserDefinedExceptionClassVariable):
+            # pyrefly: ignore # unbound-name
             return ConstantVariable.create(isinstance(arg_type, isinstance_type))

         isinstance_type_tuple: tuple[type, ...]
@@ -2138,8 +2149,10 @@ class BuiltinVariable(VariableTracker):
             # through it. This is a limitation of the current implementation.
             # Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it
             # might not be a big issue and we trade off it for performance.
+            # pyrefly: ignore # unbound-name
             val = issubclass(arg_type, isinstance_type_tuple)
         except TypeError:
+            # pyrefly: ignore # unbound-name
             val = arg_type in isinstance_type_tuple
         return variables.ConstantVariable.create(val)

@@ -2161,6 +2174,7 @@ class BuiltinVariable(VariableTracker):

         # WARNING: This might run arbitrary user code `__subclasscheck__`.
         # See the comment in call_isinstance above.
+        # pyrefly: ignore # unbound-name
         return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))

     def call_super(self, tx: "InstructionTranslator", a, b):
@@ -2206,7 +2220,9 @@ class BuiltinVariable(VariableTracker):
             value = getattr(self.fn, name)
         except AttributeError:
             raise_observed_exception(AttributeError, tx)
+        # pyrefly: ignore # unbound-name
         if not callable(value):
+            # pyrefly: ignore # unbound-name
             return VariableTracker.build(tx, value, source)
         return variables.GetAttrVariable(self, name, source=source)

@@ -651,6 +651,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         def handle_use_deterministic_algorithms(
             self, tx: "InstructionTranslator", mode, warn_only=False
         ):
+            # pyrefly: ignore # missing-attribute
             if warn_only and warn_only.as_python_constant():
                 unimplemented_v2(
                     gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)",
@@ -1035,6 +1036,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             else:
                 raise torch._dynamo.exc.Unsupported("branch not supported")
             return variables.ConstantVariable.create(
+                # pyrefly: ignore # bad-argument-type
                 torch.fx.experimental.symbolic_shapes.guard_scalar(val)
             )

@@ -1081,6 +1083,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                 return

             return variables.ConstantVariable.create(
+                # pyrefly: ignore # bad-argument-type
                 torch.fx.experimental.symbolic_shapes.has_static_value(val)
             )

@@ -1212,13 +1215,17 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             )

             # need to guard only on no-arg get_device_module
+            # pyrefly: ignore # unbound-name
             if device is None:
                 source = CallFunctionNoArgsSource(self.source)
                 install_guard(source.make_guard(GuardBuilder.ID_MATCH))
                 # assumes `module` is in the form `torch.xyz`
                 new_source = AttrSource(
-                    TorchSource(), module.__name__.rsplit(".", maxsplit=1)[-1]
+                    TorchSource(),
+                    # pyrefly: ignore # unbound-name
+                    module.__name__.rsplit(".", maxsplit=1)[-1],
                 )
+            # pyrefly: ignore # unbound-name
             return VariableTracker.build(tx, module, new_source)

         @register(torch.set_default_device)
@@ -1373,9 +1380,12 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                 f"{fn.__name__}_spec", f_spec
             )
             input_spec_proxy = tx.output.register_static_attr_and_return_proxy(
-                fn.__name__ + "_input_spec", input_spec
+                fn.__name__ + "_input_spec",
+                # pyrefly: ignore # unbound-name
+                input_spec,
             )
             f_spec_proxy.node.type = type(f_spec)
+            # pyrefly: ignore # unbound-name
             input_spec_proxy.node.type = type(input_spec)
             all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args)

@@ -1716,6 +1726,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
         )

         # this results in cleaner graphs, but only works for inputs
+        # pyrefly: ignore # missing-attribute
         if data.source:
             return cls._nn_param_via_prefix_insert(tx, data, requires_grad)

@@ -1734,7 +1745,9 @@ For now, dynamo will explicitly graph break when it encounters user code with th

         # TODO[@lucaskabela]: Remove the behavior below since it is deprecated
         if isinstance(
-            data, TensorWithTFOverrideVariable
+            data,
+            TensorWithTFOverrideVariable,
+            # pyrefly: ignore # missing-attribute
         ) or is_traceable_wrapper_subclass_type(data.class_type):
             unimplemented_v2(
                 gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass",
@@ -1757,8 +1770,11 @@ For now, dynamo will explicitly graph break when it encounters user code with th
         )

         try:
+            # pyrefly: ignore # missing-attribute
             shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
+            # pyrefly: ignore # missing-attribute
             dtype = data.var_getattr(tx, "dtype").as_python_constant()
+            # pyrefly: ignore # missing-attribute
             device = data.var_getattr(tx, "device").as_python_constant()
         except NotImplementedError as e:
             unimplemented_v2(
@@ -1773,9 +1789,13 @@ For now, dynamo will explicitly graph break when it encounters user code with th
         )

         placeholder = tx.output.synthetic_graph_input(
-            new_parameter_placeholder, [shape, dtype, device, requires_grad]
+            new_parameter_placeholder,
+            # pyrefly: ignore # unbound-name
+            [shape, dtype, device, requires_grad],
         )
+        # pyrefly: ignore # missing-attribute
        if data.requires_grad:
+            # pyrefly: ignore # missing-attribute
            data = data.call_method(tx, "detach", [], {})

        from .builder import wrap_fx_proxy
@@ -1785,6 +1805,7 @@ For now, dynamo will explicitly graph break when it encounters user code with th
             tx.output.create_proxy(
                 "call_function",
                 tracable_create_parameter,
+                # pyrefly: ignore # missing-attribute
                 (data.as_proxy(), placeholder.as_proxy()),
                 {},
             ),
@@ -646,7 +646,6 @@ def update_schema():
     assert thrift_content[1].startswith("// checksum<<")
     thrift_checksum_real = _hash_content("\n".join(thrift_content[2:]))

-    # pyrefly: ignore # import-error
     from yaml import load, Loader

     dst = load(content, Loader=Loader)
@@ -34,7 +34,7 @@ from torch.fx._pytree import (
     _deregister_pytree_flatten_spec,
     register_pytree_flatten_spec,
 )
-from torch.utils._pytree import (  # pyrefly: ignore # deprecated
+from torch.utils._pytree import (
     _deregister_pytree_node,
     _register_pytree_node,
     Context,
@ -198,7 +198,7 @@ def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str:
|
|||||||
to a file. The yaml string can be loaded back into an operator profile
|
to a file. The yaml string can be loaded back into an operator profile
|
||||||
structure using `read_profiles_from_yaml`.
|
structure using `read_profiles_from_yaml`.
|
||||||
"""
|
"""
|
||||||
# pyrefly: ignore # import-error
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from torch._export.serde.serialize import (
|
from torch._export.serde.serialize import (
|
||||||
@ -263,7 +263,7 @@ def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]:
|
|||||||
"""
|
"""
|
||||||
Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
|
Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
|
||||||
"""
|
"""
|
||||||
# pyrefly: ignore # import-error
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from torch._export.serde.serialize import (
|
from torch._export.serde.serialize import (
|
||||||
|
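The two yaml hunks above drop `import-error` suppressions that are no longer needed once PyYAML resolves for the checker. For reference, a small standalone round-trip using the same PyYAML entry points (illustrative only; this is not the operator-profile serialization format):

import yaml

profile = {"op": "aten::add.Tensor", "num_inputs": 2}

# dump to a YAML string, then parse it back
text = yaml.dump(profile)
assert yaml.safe_load(text) == profile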
@@ -1,5 +1,5 @@
 from .core import dispatch
-from .dispatcher import ( # pyrefly: ignore # deprecated
+from .dispatcher import (
 Dispatcher,
 halt_ordering,
 MDNotImplementedError,
@@ -153,6 +153,7 @@ class CausalBias(torch.Tensor):
 diagonal=diagonal_offset,
 )

+# pyrefly: ignore # bad-return
 def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
 """
 Materializes the causal bias into a tensor form.
@@ -84,6 +84,7 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor
 _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]


+# pyrefly: ignore # invalid-inheritance
 class FlexKernelOptions(TypedDict, total=False):
 """Options for controlling the behavior of FlexAttention kernels.

@@ -127,76 +128,93 @@ class FlexKernelOptions(TypedDict, total=False):
 """

 # Performance tuning options
+# pyrefly: ignore # invalid-annotation
 num_warps: NotRequired[int]
 """Number of warps to use in the CUDA kernel. Higher values may improve performance
 but increase register pressure. Default is determined by autotuning."""

+# pyrefly: ignore # invalid-annotation
 num_stages: NotRequired[int]
 """Number of pipeline stages in the CUDA kernel. Higher values may improve performance
 but increase shared memory usage. Default is determined by autotuning."""

+# pyrefly: ignore # invalid-annotation
 BLOCK_M: NotRequired[int]
 """Thread block size for the sequence length dimension of Q in forward pass.
 Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""

+# pyrefly: ignore # invalid-annotation
 BLOCK_N: NotRequired[int]
 """Thread block size for the sequence length dimension of K/V in forward pass.
 Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""

 # Backward-specific block sizes (when prefixed with 'bwd_')
+# pyrefly: ignore # invalid-annotation
 BLOCK_M1: NotRequired[int]
 """Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'.
 Default is determined by autotuning."""

+# pyrefly: ignore # invalid-annotation
 BLOCK_N1: NotRequired[int]
 """Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'.
 Default is determined by autotuning."""

+# pyrefly: ignore # invalid-annotation
 BLOCK_M2: NotRequired[int]
 """Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'.
 Default is determined by autotuning."""

+# pyrefly: ignore # invalid-annotation
 BLOCK_N2: NotRequired[int]
 """Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'.
 Default is determined by autotuning."""

+# pyrefly: ignore # invalid-annotation
 PRESCALE_QK: NotRequired[bool]
 """Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but
 may have more numerical error. Default: False."""

+# pyrefly: ignore # invalid-annotation
 ROWS_GUARANTEED_SAFE: NotRequired[bool]
 """If True, guarantees that at least one value in each row is not masked out.
 Allows skipping safety checks for better performance. Only set this if you are certain
 your mask guarantees this property. For example, causal attention is guaranteed safe
 because each query has at least 1 key-value to attend to. Default: False."""

+# pyrefly: ignore # invalid-annotation
 BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]
 """If True, guarantees that all blocks in the mask are contiguous.
 Allows optimizing block traversal. For example, causal masks would satisfy this,
 but prefix_lm + sliding window would not. Default: False."""

+# pyrefly: ignore # invalid-annotation
 WRITE_DQ: NotRequired[bool]
 """Controls whether gradient scatters are done in the DQ iteration loop of the backward pass.
 Setting this to False will force this to happen in the DK loop which depending on your
 specific score_mod and mask_mod might be faster. Default: True."""

+# pyrefly: ignore # invalid-annotation
 FORCE_USE_FLEX_ATTENTION: NotRequired[bool]
 """If True, forces the use of the flex attention kernel instead of potentially using
 the more optimized flex-decoding kernel for short sequences. This can be a helpful
 option for debugging. Default: False."""

+# pyrefly: ignore # invalid-annotation
 USE_TMA: NotRequired[bool]
 """Whether to use Tensor Memory Accelerator (TMA) on supported hardware.
 This is experimental and may not work on all hardware, currently specific
 to NVIDIA GPUs Hopper+. Default: False."""

 # ROCm-specific options
+# pyrefly: ignore # invalid-annotation
 kpack: NotRequired[int]
 """ROCm-specific kernel packing parameter."""

+# pyrefly: ignore # invalid-annotation
 matrix_instr_nonkdim: NotRequired[int]
 """ROCm-specific matrix instruction non-K dimension."""

+# pyrefly: ignore # invalid-annotation
 waves_per_eu: NotRequired[int]
 """ROCm-specific waves per execution unit."""

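The FlexKernelOptions fields annotated above are consumed through the `kernel_options` argument of `flex_attention`; the sketch below assumes that keyword and picks arbitrary tile sizes, so treat it as illustrative rather than tuned advice:

import torch
from torch.nn.attention.flex_attention import flex_attention

# assumes a CUDA device; flex_attention is typically driven under torch.compile
q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)

# Forward-pass tile sizes; the backward pass reads the same keys with a 'bwd_' prefix.
opts = {"BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": False}
out = flex_attention(q, k, v, kernel_options=opts)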
@@ -581,6 +599,7 @@ class BlockMask:
 block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
 seq_lengths = (self.seq_lengths,) # type: ignore[assignment]

+# pyrefly: ignore # not-iterable
 return (
 *seq_lengths,
 self.kv_num_blocks,
@@ -753,6 +772,7 @@ class BlockMask:
 partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
 if self.full_kv_num_blocks is not None:
 assert self.full_kv_indices is not None
+# pyrefly: ignore # bad-return
 return partial_dense | _ordered_to_dense(
 self.full_kv_num_blocks, self.full_kv_indices
 )
@@ -78,6 +78,7 @@ class ModuleWrapper(nn.Module):

 # nn.Module defines training as a boolean
 @property # type: ignore[override]
+# pyrefly: ignore # bad-override
 def training(self):
 return self.cpp_module.training

@@ -1266,6 +1266,7 @@ def adaptive_max_pool2d_with_indices(
 output_size,
 return_indices=return_indices,
 )
+# pyrefly: ignore # bad-argument-type
 output_size = _list_with_default(output_size, input.size())
 return torch._C._nn.adaptive_max_pool2d(input, output_size)

@@ -1323,6 +1324,7 @@ def adaptive_max_pool3d_with_indices(
 output_size,
 return_indices=return_indices,
 )
+# pyrefly: ignore # bad-argument-type
 output_size = _list_with_default(output_size, input.size())
 return torch._C._nn.adaptive_max_pool3d(input, output_size)

@@ -1381,6 +1383,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> T
 """
 if has_torch_function_unary(input):
 return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
+# pyrefly: ignore # bad-argument-type
 _output_size = _list_with_default(output_size, input.size())
 return torch._C._nn.adaptive_avg_pool2d(input, _output_size)

@@ -1396,6 +1399,7 @@ def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> T
 """
 if has_torch_function_unary(input):
 return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
+# pyrefly: ignore # bad-argument-type
 _output_size = _list_with_default(output_size, input.size())
 return torch._C._nn.adaptive_avg_pool3d(input, _output_size)

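The adaptive pooling wrappers patched above take either a single int or a per-dimension tuple for `output_size`; a short usage sketch of the public functional API:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 40)

# Pool to a fixed 7x7 grid regardless of the input resolution.
y = F.adaptive_avg_pool2d(x, (7, 7))
# A single int broadcasts to both spatial dims; the *_with_indices path also returns argmax indices.
z, idx = F.adaptive_max_pool2d(x, 7, return_indices=True)
print(y.shape, z.shape)  # torch.Size([1, 3, 7, 7]) torch.Size([1, 3, 7, 7])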
@@ -2431,6 +2435,7 @@ def _no_grad_embedding_renorm_(
 input: Tensor,
 max_norm: float,
 norm_type: float,
+# pyrefly: ignore # bad-return
 ) -> tuple[Tensor, Tensor]:
 torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)

@@ -2684,6 +2689,7 @@ def embedding_bag(

 if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested:
 include_last_offset = True
+# pyrefly: ignore # missing-attribute
 offsets = input.offsets()
 input = input.values().reshape(-1)
 if per_sample_weights is not None:
@@ -2818,6 +2824,7 @@ def batch_norm(
 eps=eps,
 )
 if training:
+# pyrefly: ignore # bad-argument-type
 _verify_batch_size(input.size())

 return torch.batch_norm(
@@ -2873,6 +2880,7 @@ def instance_norm(
 eps=eps,
 )
 if use_input_stats:
+# pyrefly: ignore # bad-argument-type
 _verify_spatial_size(input.size())
 return torch.instance_norm(
 input,
@@ -2998,11 +3006,13 @@ def local_response_norm(
 div = input.mul(input)
 if dim == 3:
 div = div.unsqueeze(1)
+# pyrefly: ignore # bad-argument-type
 div = pad(div, (0, 0, size // 2, (size - 1) // 2))
 div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
 else:
 sizes = input.size()
 div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
+# pyrefly: ignore # bad-argument-type
 div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
 div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
 div = div.view(sizes)
@@ -3151,7 +3161,12 @@ def nll_loss(
 if size_average is not None or reduce is not None:
 reduction = _Reduction.legacy_get_string(size_average, reduce)
 return torch._C._nn.nll_loss_nd(
-input, target, weight, _Reduction.get_enum(reduction), ignore_index
+input,
+target,
+weight,
+# pyrefly: ignore # bad-argument-type
+_Reduction.get_enum(reduction),
+ignore_index,
 )


@@ -3296,6 +3311,7 @@ def gaussian_nll_loss(
 var.clamp_(min=eps)

 # Calculate the loss
+# pyrefly: ignore # unsupported-operation
 loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
 if full:
 loss += 0.5 * math.log(2 * math.pi)
@@ -3471,6 +3487,7 @@ def cross_entropy(
 input,
 target,
 weight,
+# pyrefly: ignore # bad-argument-type
 _Reduction.get_enum(reduction),
 ignore_index,
 label_smoothing,
@@ -3535,6 +3552,7 @@ def binary_cross_entropy(
 new_size = _infer_size(target.size(), weight.size())
 weight = weight.expand(new_size)

+# pyrefly: ignore # bad-argument-type
 return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)


@@ -3663,11 +3681,18 @@ def smooth_l1_loss(

 if beta == 0.0:
 return torch._C._nn.l1_loss(
-expanded_input, expanded_target, _Reduction.get_enum(reduction)
+expanded_input,
+expanded_target,
+# pyrefly: ignore # bad-argument-type
+_Reduction.get_enum(reduction),
 )
 else:
 return torch._C._nn.smooth_l1_loss(
-expanded_input, expanded_target, _Reduction.get_enum(reduction), beta
+expanded_input,
+expanded_target,
+# pyrefly: ignore # bad-argument-type
+_Reduction.get_enum(reduction),
+beta,
 )


@@ -3725,7 +3750,11 @@ def huber_loss(
 if weight is None:
 # Use the optimized C++ backend for standard Huber loss
 return torch._C._nn.huber_loss(
-expanded_input, expanded_target, _Reduction.get_enum(reduction), delta
+expanded_input,
+expanded_target,
+# pyrefly: ignore # bad-argument-type
+_Reduction.get_enum(reduction),
+delta,
 )
 else:
 if weight.size() != input.size():
@@ -3733,7 +3762,11 @@ def huber_loss(

 # Calculate the unweighted loss first
 unweighted_loss = torch._C._nn.huber_loss(
-expanded_input, expanded_target, _Reduction.get_enum("none"), delta
+expanded_input,
+expanded_target,
+# pyrefly: ignore # bad-argument-type
+_Reduction.get_enum("none"),
+delta,
 )

 # Apply weight to the unweighted loss
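The loss hunks above only reflow call arguments so a suppression can sit on the `_Reduction.get_enum(...)` line; the public loss functions behave as before. A quick sketch:

import torch
import torch.nn.functional as F

pred = torch.randn(8, 5)
target = torch.randn(8, 5)

# With beta == 0.0, smooth_l1_loss falls back to the l1_loss branch shown above.
l1 = F.smooth_l1_loss(pred, target, beta=0.0)
# huber_loss takes its quadratic-to-linear transition point via `delta`.
hub = F.huber_loss(pred, target, delta=1.0, reduction="mean")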
@@ -3820,7 +3853,10 @@ def l1_loss(
 )
 else:
 return torch._C._nn.l1_loss(
-expanded_input, expanded_target, _Reduction.get_enum(reduction)
+expanded_input,
+expanded_target,
+# pyrefly: ignore # bad-argument-type
+_Reduction.get_enum(reduction),
 )


@@ -3895,7 +3931,10 @@ def mse_loss(
 )
 else:
 return torch._C._nn.mse_loss(
-expanded_input, expanded_target, _Reduction.get_enum(reduction)
+expanded_input,
+expanded_target,
+# pyrefly: ignore # bad-argument-type
+_Reduction.get_enum(reduction),
 )


@@ -4032,6 +4071,7 @@ def multilabel_margin_loss(
 reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
 else:
 reduction_enum = _Reduction.get_enum(reduction)
+# pyrefly: ignore # bad-argument-type
 return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)


@@ -4073,6 +4113,7 @@ def soft_margin_loss(
 reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
 else:
 reduction_enum = _Reduction.get_enum(reduction)
+# pyrefly: ignore # bad-argument-type
 return torch._C._nn.soft_margin_loss(input, target, reduction_enum)


@@ -4237,7 +4278,13 @@ def multi_margin_loss(
 raise ValueError("weight must be one-dimensional")

 return torch._C._nn.multi_margin_loss(
-input, target, p, margin, weight, reduction_enum
+input,
+target,
+p,
+margin,
+weight,
+# pyrefly: ignore # bad-argument-type
+reduction_enum,
 )


@@ -4383,6 +4430,7 @@ def upsample( # noqa: F811
 scale_factor: Optional[float] = None,
 mode: str = "nearest",
 align_corners: Optional[bool] = None,
+# pyrefly: ignore # bad-return
 ) -> Tensor: # noqa: B950
 pass

@@ -4394,6 +4442,7 @@ def upsample( # noqa: F811
 scale_factor: Optional[float] = None,
 mode: str = "nearest",
 align_corners: Optional[bool] = None,
+# pyrefly: ignore # bad-return
 ) -> Tensor: # noqa: B950
 pass

@@ -4496,6 +4545,7 @@ def interpolate( # noqa: F811
 align_corners: Optional[bool] = None,
 recompute_scale_factor: Optional[bool] = None,
 antialias: bool = False,
+# pyrefly: ignore # bad-return
 ) -> Tensor: # noqa: B950
 pass

@@ -4509,6 +4559,7 @@ def interpolate( # noqa: F811
 align_corners: Optional[bool] = None,
 recompute_scale_factor: Optional[bool] = None,
 antialias: bool = False,
+# pyrefly: ignore # bad-return
 ) -> Tensor: # noqa: B950
 pass

@@ -4522,6 +4573,7 @@ def interpolate( # noqa: F811
 align_corners: Optional[bool] = None,
 recompute_scale_factor: Optional[bool] = None,
 antialias: bool = False,
+# pyrefly: ignore # bad-return
 ) -> Tensor: # noqa: B950
 pass

@@ -4535,6 +4587,7 @@ def interpolate( # noqa: F811
 align_corners: Optional[bool] = None,
 recompute_scale_factor: Optional[bool] = None,
 antialias: bool = False,
+# pyrefly: ignore # bad-return
 ) -> Tensor:
 pass

@@ -4709,6 +4762,7 @@ def interpolate( # noqa: F811
 (
 torch.floor(
 (
+# pyrefly: ignore # missing-attribute
 input.size(i + 2).float()
 * torch.tensor(scale_factors[i], dtype=torch.float32)
 ).float()
@@ -4733,21 +4787,28 @@ def interpolate( # noqa: F811
 )

 if input.dim() == 3 and mode == "nearest":
+# pyrefly: ignore # bad-argument-type
 return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
 if input.dim() == 4 and mode == "nearest":
+# pyrefly: ignore # bad-argument-type
 return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
 if input.dim() == 5 and mode == "nearest":
+# pyrefly: ignore # bad-argument-type
 return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)

 if input.dim() == 3 and mode == "nearest-exact":
+# pyrefly: ignore # bad-argument-type
 return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
 if input.dim() == 4 and mode == "nearest-exact":
+# pyrefly: ignore # bad-argument-type
 return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
 if input.dim() == 5 and mode == "nearest-exact":
+# pyrefly: ignore # bad-argument-type
 return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)

 if input.dim() == 3 and mode == "area":
 assert output_size is not None
+# pyrefly: ignore # bad-argument-type
 return adaptive_avg_pool1d(input, output_size)
 if input.dim() == 4 and mode == "area":
 assert output_size is not None
@@ -4759,13 +4820,21 @@ def interpolate( # noqa: F811
 if input.dim() == 3 and mode == "linear":
 assert align_corners is not None
 return torch._C._nn.upsample_linear1d(
-input, output_size, align_corners, scale_factors
+input,
+# pyrefly: ignore # bad-argument-type
+output_size,
+align_corners,
+scale_factors,
 )
 if input.dim() == 4 and mode == "bilinear":
 assert align_corners is not None
 if antialias:
 return torch._C._nn._upsample_bilinear2d_aa(
-input, output_size, align_corners, scale_factors
+input,
+# pyrefly: ignore # bad-argument-type
+output_size,
+align_corners,
+scale_factors,
 )
 # Two levels are necessary to prevent TorchScript from touching
 # are_deterministic_algorithms_enabled.
@@ -4778,7 +4847,11 @@ def interpolate( # noqa: F811
 "torch._decomp.decompositions"
 )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
 return torch._C._nn.upsample_bilinear2d(
-input, output_size, align_corners, scale_factors
+input,
+# pyrefly: ignore # bad-argument-type
+output_size,
+align_corners,
+scale_factors,
 )
 if input.dim() == 5 and mode == "trilinear":
 assert align_corners is not None
@@ -4793,16 +4866,28 @@ def interpolate( # noqa: F811
 "torch._decomp.decompositions"
 )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
 return torch._C._nn.upsample_trilinear3d(
-input, output_size, align_corners, scale_factors
+input,
+# pyrefly: ignore # bad-argument-type
+output_size,
+align_corners,
+scale_factors,
 )
 if input.dim() == 4 and mode == "bicubic":
 assert align_corners is not None
 if antialias:
 return torch._C._nn._upsample_bicubic2d_aa(
-input, output_size, align_corners, scale_factors
+input,
+# pyrefly: ignore # bad-argument-type
+output_size,
+align_corners,
+scale_factors,
 )
 return torch._C._nn.upsample_bicubic2d(
-input, output_size, align_corners, scale_factors
+input,
+# pyrefly: ignore # bad-argument-type
+output_size,
+align_corners,
+scale_factors,
 )

 if input.dim() == 3 and mode == "bilinear":
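The interpolate hunks annotate the dispatch into the private `torch._C._nn` upsample kernels; the public entry point is unchanged. Usage sketch:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 64, 64)

# Scale by a factor; 'nearest' ignores align_corners.
up = F.interpolate(x, scale_factor=2.0, mode="nearest")
# Or request an explicit size with an antialiased bilinear kernel.
down = F.interpolate(x, size=(32, 32), mode="bilinear", align_corners=False, antialias=True)
print(up.shape, down.shape)  # torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 32, 32])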
@@ -4834,6 +4919,7 @@ def upsample_nearest( # noqa: F811
 input: Tensor,
 size: Optional[int] = None,
 scale_factor: Optional[float] = None,
+# pyrefly: ignore # bad-return
 ) -> Tensor:
 pass

@@ -4843,6 +4929,7 @@ def upsample_nearest( # noqa: F811
 input: Tensor,
 size: Optional[list[int]] = None,
 scale_factor: Optional[float] = None,
+# pyrefly: ignore # bad-return
 ) -> Tensor:
 pass

@@ -4884,6 +4971,7 @@ def upsample_bilinear( # noqa: F811
 input: Tensor,
 size: Optional[int] = None,
 scale_factor: Optional[float] = None,
+# pyrefly: ignore # bad-return
 ) -> Tensor:
 pass

@@ -4893,6 +4981,7 @@ def upsample_bilinear( # noqa: F811
 input: Tensor,
 size: Optional[list[int]] = None,
 scale_factor: Optional[float] = None,
+# pyrefly: ignore # bad-return
 ) -> Tensor:
 pass

@@ -4902,6 +4991,7 @@ def upsample_bilinear( # noqa: F811
 input: Tensor,
 size: Optional[int] = None,
 scale_factor: Optional[list[float]] = None,
+# pyrefly: ignore # bad-return
 ) -> Tensor:
 pass

@@ -4911,6 +5001,7 @@ def upsample_bilinear( # noqa: F811
 input: Tensor,
 size: Optional[list[int]] = None,
 scale_factor: Optional[list[float]] = None,
+# pyrefly: ignore # bad-return
 ) -> Tensor:
 pass

@@ -5717,6 +5808,7 @@ def _in_projection_packed(
 .squeeze(-2)
 .contiguous()
 )
+# pyrefly: ignore # bad-return
 return proj[0], proj[1], proj[2]
 else:
 # encoder-decoder attention
@@ -5735,6 +5827,7 @@ def _in_projection_packed(
 .squeeze(-2)
 .contiguous()
 )
+# pyrefly: ignore # bad-return
 return (q_proj, kv_proj[0], kv_proj[1])
 else:
 w_q, w_k, w_v = w.chunk(3)
@@ -5742,6 +5835,7 @@ def _in_projection_packed(
 b_q = b_k = b_v = None
 else:
 b_q, b_k, b_v = b.chunk(3)
+# pyrefly: ignore # bad-return
 return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


@@ -6372,8 +6466,10 @@ def multi_head_attention_forward(
 k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
 v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
 if attn_mask is not None:
+# pyrefly: ignore # bad-argument-type
 attn_mask = pad(attn_mask, (0, 1))
 if key_padding_mask is not None:
+# pyrefly: ignore # bad-argument-type
 key_padding_mask = pad(key_padding_mask, (0, 1))
 else:
 assert bias_k is None
@@ -6382,8 +6478,10 @@ def multi_head_attention_forward(
 #
 # reshape q, k, v for multihead attention and make them batch first
 #
+# pyrefly: ignore # no-matching-overload
 q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
 if static_k is None:
+# pyrefly: ignore # no-matching-overload
 k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
 else:
 # TODO finish disentangling control flow so we don't do in-projections when statics are passed
@@ -6395,6 +6493,7 @@ def multi_head_attention_forward(
 )
 k = static_k
 if static_v is None:
+# pyrefly: ignore # no-matching-overload
 v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
 else:
 # TODO finish disentangling control flow so we don't do in-projections when statics are passed
@@ -6410,14 +6509,20 @@ def multi_head_attention_forward(
 if add_zero_attn:
 zero_attn_shape = (bsz * num_heads, 1, head_dim)
 k = torch.cat(
-[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
+# pyrefly: ignore # no-matching-overload
+[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
+dim=1,
 )
 v = torch.cat(
-[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
+# pyrefly: ignore # no-matching-overload
+[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
+dim=1,
 )
 if attn_mask is not None:
+# pyrefly: ignore # bad-argument-type
 attn_mask = pad(attn_mask, (0, 1))
 if key_padding_mask is not None:
+# pyrefly: ignore # bad-argument-type
 key_padding_mask = pad(key_padding_mask, (0, 1))

 # update source sequence length after adjustments
@@ -6467,6 +6572,7 @@ def multi_head_attention_forward(
 attn_output = torch.bmm(attn_output_weights, v)

 attn_output = (
+# pyrefly: ignore # no-matching-overload
 attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
 )
 attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
@@ -6493,13 +6599,16 @@ def multi_head_attention_forward(
 attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

 q = q.view(bsz, num_heads, tgt_len, head_dim)
+# pyrefly: ignore # no-matching-overload
 k = k.view(bsz, num_heads, src_len, head_dim)
+# pyrefly: ignore # no-matching-overload
 v = v.view(bsz, num_heads, src_len, head_dim)

 attn_output = scaled_dot_product_attention(
 q, k, v, attn_mask, dropout_p, is_causal
 )
 attn_output = (
+# pyrefly: ignore # no-matching-overload
 attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
 )

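The multi_head_attention_forward hunks only pin suppressions onto the reshape and concat calls; callers normally reach this code through nn.MultiheadAttention, as in this short sketch:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

x = torch.randn(2, 10, 64)  # (batch, seq, embed)
key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)  # True marks positions to ignore
out, attn_weights = mha(x, x, x, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 10, 64])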
@@ -500,6 +500,7 @@ def xavier_normal_(


 def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int:
+# pyrefly: ignore # bad-assignment
 mode = mode.lower()
 valid_modes = ["fan_in", "fan_out"]
 if mode not in valid_modes:
@@ -6,6 +6,7 @@ from torch.autograd.function import Function

 class SyncBatchNorm(Function):
 @staticmethod
+# pyrefly: ignore # bad-override
 def forward(
 self,
 input,
@@ -210,6 +211,7 @@ class SyncBatchNorm(Function):

 class CrossMapLRN2d(Function):
 @staticmethod
+# pyrefly: ignore # bad-override
 def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
 ctx.size = size
 ctx.alpha = alpha
@@ -265,6 +267,7 @@ class CrossMapLRN2d(Function):
 return output

 @staticmethod
+# pyrefly: ignore # bad-override
 def backward(ctx, grad_output):
 input, output = ctx.saved_tensors
 grad_input = grad_output.new()
@@ -306,6 +309,7 @@ class CrossMapLRN2d(Function):

 class BackwardHookFunction(torch.autograd.Function):
 @staticmethod
+# pyrefly: ignore # bad-override
 def forward(ctx, *args):
 ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
 return args

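The `bad-override` suppressions above sit on `Function.forward`/`backward` staticmethods, whose signatures intentionally differ from the base class. A minimal custom autograd Function in the same shape (illustrative, unrelated to the patched classes):

import torch


class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        # Stash the non-tensor argument on ctx for the backward pass.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; the plain float gets None.
        return grad_output * ctx.factor, None


x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])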
|
@ -72,6 +72,7 @@ class _NormBase(Module):
|
|||||||
torch.tensor(
|
torch.tensor(
|
||||||
0,
|
0,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
|
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -221,6 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
super().__init__(
|
super().__init__(
|
||||||
# affine and track_running_stats are hardcoded to False to
|
# affine and track_running_stats are hardcoded to False to
|
||||||
# avoid creating tensors that will soon be overwritten.
|
# avoid creating tensors that will soon be overwritten.
|
||||||
@ -234,22 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
|
|||||||
self.affine = affine
|
self.affine = affine
|
||||||
self.track_running_stats = track_running_stats
|
self.track_running_stats = track_running_stats
|
||||||
if self.affine:
|
if self.affine:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.weight = UninitializedParameter(**factory_kwargs)
|
self.weight = UninitializedParameter(**factory_kwargs)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.bias = UninitializedParameter(**factory_kwargs)
|
self.bias = UninitializedParameter(**factory_kwargs)
|
||||||
if self.track_running_stats:
|
if self.track_running_stats:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.running_mean = UninitializedBuffer(**factory_kwargs)
|
self.running_mean = UninitializedBuffer(**factory_kwargs)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.running_var = UninitializedBuffer(**factory_kwargs)
|
self.running_var = UninitializedBuffer(**factory_kwargs)
|
||||||
self.num_batches_tracked = torch.tensor(
|
self.num_batches_tracked = torch.tensor(
|
||||||
0,
|
0,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
|
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset_parameters(self) -> None:
|
def reset_parameters(self) -> None:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if not self.has_uninitialized_params() and self.num_features != 0:
|
if not self.has_uninitialized_params() and self.num_features != 0:
|
||||||
super().reset_parameters()
|
super().reset_parameters()
|
||||||
|
|
||||||
def initialize_parameters(self, input) -> None: # type: ignore[override]
|
def initialize_parameters(self, input) -> None: # type: ignore[override]
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if self.has_uninitialized_params():
|
if self.has_uninitialized_params():
|
||||||
self.num_features = input.shape[1]
|
self.num_features = input.shape[1]
|
||||||
if self.affine:
|
if self.affine:
|
||||||
|
@ -109,6 +109,7 @@ class Sequential(Module):
|
|||||||
def __init__(self, *args: Module) -> None: ...
|
def __init__(self, *args: Module) -> None: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
# pyrefly: ignore # inconsistent-overload
|
||||||
def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
|
def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
|
||||||
|
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
@ -472,6 +473,7 @@ class ModuleList(Module):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def pop(self, key: Union[int, slice]) -> Module:
|
def pop(self, key: Union[int, slice]) -> Module:
|
||||||
|
# pyrefly: ignore # index-error
|
||||||
v = self[key]
|
v = self[key]
|
||||||
del self[key]
|
del self[key]
|
||||||
return v
|
return v
|
||||||
@ -623,9 +625,11 @@ class ModuleDict(Module):
|
|||||||
"ModuleDict update sequence element "
|
"ModuleDict update sequence element "
|
||||||
"#" + str(j) + " should be Iterable; is" + type(m).__name__
|
"#" + str(j) + " should be Iterable; is" + type(m).__name__
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if not len(m) == 2:
|
if not len(m) == 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"ModuleDict update sequence element "
|
"ModuleDict update sequence element "
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
"#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
|
"#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
|
||||||
)
|
)
|
||||||
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
|
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
|
||||||
@ -684,6 +688,7 @@ class ParameterList(Module):
|
|||||||
def __getitem__(self, idx: int) -> Any: ...
|
def __getitem__(self, idx: int) -> Any: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
# pyrefly: ignore # inconsistent-overload
|
||||||
def __getitem__(self: T, idx: slice) -> T: ...
|
def __getitem__(self: T, idx: slice) -> T: ...
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
@ -769,9 +774,11 @@ class ParameterList(Module):
|
|||||||
size_str,
|
size_str,
|
||||||
device_str,
|
device_str,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
child_lines.append(" (" + str(k) + "): " + parastr)
|
child_lines.append(" (" + str(k) + "): " + parastr)
|
||||||
else:
|
else:
|
||||||
child_lines.append(
|
child_lines.append(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
" (" + str(k) + "): Object of type: " + type(p).__name__
|
" (" + str(k) + "): Object of type: " + type(p).__name__
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -979,9 +986,11 @@ class ParameterDict(Module):
|
|||||||
"ParameterDict update sequence element "
|
"ParameterDict update sequence element "
|
||||||
"#" + str(j) + " should be Iterable; is" + type(p).__name__
|
"#" + str(j) + " should be Iterable; is" + type(p).__name__
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if not len(p) == 2:
|
if not len(p) == 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"ParameterDict update sequence element "
|
"ParameterDict update sequence element "
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
"#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
|
"#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
|
||||||
)
|
)
|
||||||
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
|
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
|
||||||
@ -1002,9 +1011,11 @@ class ParameterDict(Module):
|
|||||||
size_str,
|
size_str,
|
||||||
device_str,
|
device_str,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
child_lines.append(" (" + str(k) + "): " + parastr)
|
child_lines.append(" (" + str(k) + "): " + parastr)
|
||||||
else:
|
else:
|
||||||
child_lines.append(
|
child_lines.append(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
" (" + str(k) + "): Object of type: " + type(p).__name__
|
" (" + str(k) + "): Object of type: " + type(p).__name__
|
||||||
)
|
)
|
||||||
tmpstr = "\n".join(child_lines)
|
tmpstr = "\n".join(child_lines)
|
||||||
|
@ -363,6 +363,7 @@ class Conv1d(_ConvNd):
|
|||||||
self.dilation,
|
self.dilation,
|
||||||
self.groups,
|
self.groups,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return F.conv1d(
|
return F.conv1d(
|
||||||
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
||||||
)
|
)
|
||||||
@ -540,6 +541,7 @@ class Conv2d(_ConvNd):
|
|||||||
self.dilation,
|
self.dilation,
|
||||||
self.groups,
|
self.groups,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return F.conv2d(
|
return F.conv2d(
|
||||||
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
||||||
)
|
)
|
||||||
@ -709,6 +711,7 @@ class Conv3d(_ConvNd):
|
|||||||
self.dilation,
|
self.dilation,
|
||||||
self.groups,
|
self.groups,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return F.conv3d(
|
return F.conv3d(
|
||||||
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
||||||
)
|
)
|
||||||
@ -1494,6 +1497,7 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
super().__init__(
|
super().__init__(
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@ -1508,9 +1512,11 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
|
|||||||
padding_mode,
|
padding_mode,
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.weight = UninitializedParameter(**factory_kwargs)
|
self.weight = UninitializedParameter(**factory_kwargs)
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
if bias:
|
if bias:
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.bias = UninitializedParameter(**factory_kwargs)
|
self.bias = UninitializedParameter(**factory_kwargs)
|
||||||
|
|
||||||
def _get_num_spatial_dims(self) -> int:
|
def _get_num_spatial_dims(self) -> int:
|
||||||
@ -1563,6 +1569,7 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
super().__init__(
|
super().__init__(
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@ -1577,9 +1584,11 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
|
|||||||
padding_mode,
|
padding_mode,
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.weight = UninitializedParameter(**factory_kwargs)
|
self.weight = UninitializedParameter(**factory_kwargs)
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
if bias:
|
if bias:
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.bias = UninitializedParameter(**factory_kwargs)
|
self.bias = UninitializedParameter(**factory_kwargs)
|
||||||
|
|
||||||
def _get_num_spatial_dims(self) -> int:
|
def _get_num_spatial_dims(self) -> int:
|
||||||
@ -1633,6 +1642,7 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
super().__init__(
|
super().__init__(
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@ -1647,9 +1657,11 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
|
|||||||
padding_mode,
|
padding_mode,
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.weight = UninitializedParameter(**factory_kwargs)
|
self.weight = UninitializedParameter(**factory_kwargs)
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
if bias:
|
if bias:
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.bias = UninitializedParameter(**factory_kwargs)
|
self.bias = UninitializedParameter(**factory_kwargs)
|
||||||
|
|
||||||
def _get_num_spatial_dims(self) -> int:
|
def _get_num_spatial_dims(self) -> int:
|
||||||
@ -1701,6 +1713,7 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
super().__init__(
|
super().__init__(
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@ -1716,9 +1729,11 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
|
|||||||
padding_mode,
|
padding_mode,
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.weight = UninitializedParameter(**factory_kwargs)
|
self.weight = UninitializedParameter(**factory_kwargs)
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
if bias:
|
if bias:
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.bias = UninitializedParameter(**factory_kwargs)
|
self.bias = UninitializedParameter(**factory_kwargs)
|
||||||
|
|
||||||
def _get_num_spatial_dims(self) -> int:
|
def _get_num_spatial_dims(self) -> int:
|
||||||
@ -1770,6 +1785,7 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
super().__init__(
|
super().__init__(
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@ -1785,9 +1801,11 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
|
|||||||
padding_mode,
|
padding_mode,
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.weight = UninitializedParameter(**factory_kwargs)
|
self.weight = UninitializedParameter(**factory_kwargs)
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
if bias:
|
if bias:
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.bias = UninitializedParameter(**factory_kwargs)
|
self.bias = UninitializedParameter(**factory_kwargs)
|
||||||
|
|
||||||
def _get_num_spatial_dims(self) -> int:
|
def _get_num_spatial_dims(self) -> int:
|
||||||
@ -1839,6 +1857,7 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
super().__init__(
|
super().__init__(
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
@ -1854,9 +1873,11 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
|
|||||||
padding_mode,
|
padding_mode,
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.weight = UninitializedParameter(**factory_kwargs)
|
self.weight = UninitializedParameter(**factory_kwargs)
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
if bias:
|
if bias:
|
||||||
|
# pyrefly: ignore # bad-override, bad-argument-type
|
||||||
self.bias = UninitializedParameter(**factory_kwargs)
|
self.bias = UninitializedParameter(**factory_kwargs)
|
||||||
|
|
||||||
def _get_num_spatial_dims(self) -> int:
|
def _get_num_spatial_dims(self) -> int:
|
||||||
|
@ -172,7 +172,9 @@ class LazyModuleMixin:
|
|||||||
def __init__(self: _LazyProtocol, *args, **kwargs):
|
def __init__(self: _LazyProtocol, *args, **kwargs):
|
||||||
# Mypy doesn't like this super call in a mixin
|
# Mypy doesn't like this super call in a mixin
|
||||||
super().__init__(*args, **kwargs) # type: ignore[misc]
|
super().__init__(*args, **kwargs) # type: ignore[misc]
|
||||||
|
# pyrefly: ignore # read-only
|
||||||
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
|
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
|
||||||
|
# pyrefly: ignore # read-only
|
||||||
self._initialize_hook = self.register_forward_pre_hook(
|
self._initialize_hook = self.register_forward_pre_hook(
|
||||||
self._infer_parameters, with_kwargs=True
|
self._infer_parameters, with_kwargs=True
|
||||||
)
|
)
|
||||||
|
@ -286,6 +286,7 @@ class LazyLinear(LazyModuleMixin, Linear):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
cls_to_become = Linear # type: ignore[assignment]
|
cls_to_become = Linear # type: ignore[assignment]
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
weight: UninitializedParameter
|
weight: UninitializedParameter
|
||||||
bias: UninitializedParameter # type: ignore[assignment]
|
bias: UninitializedParameter # type: ignore[assignment]
|
||||||
|
|
||||||
@ -295,16 +296,20 @@ class LazyLinear(LazyModuleMixin, Linear):
|
|||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
# bias is hardcoded to False to avoid creating tensor
|
# bias is hardcoded to False to avoid creating tensor
|
||||||
# that will soon be overwritten.
|
# that will soon be overwritten.
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
super().__init__(0, 0, False)
|
super().__init__(0, 0, False)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.weight = UninitializedParameter(**factory_kwargs)
|
self.weight = UninitializedParameter(**factory_kwargs)
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
if bias:
|
if bias:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.bias = UninitializedParameter(**factory_kwargs)
|
self.bias = UninitializedParameter(**factory_kwargs)
|
||||||
|
|
||||||
def reset_parameters(self) -> None:
|
def reset_parameters(self) -> None:
|
||||||
"""
|
"""
|
||||||
Resets parameters based on their initialization used in ``__init__``.
|
Resets parameters based on their initialization used in ``__init__``.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if not self.has_uninitialized_params() and self.in_features != 0:
|
if not self.has_uninitialized_params() and self.in_features != 0:
|
||||||
super().reset_parameters()
|
super().reset_parameters()
|
||||||
|
|
||||||
@ -312,6 +317,7 @@ class LazyLinear(LazyModuleMixin, Linear):
|
|||||||
"""
|
"""
|
||||||
Infers ``in_features`` based on ``input`` and initializes parameters.
|
Infers ``in_features`` based on ``input`` and initializes parameters.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if self.has_uninitialized_params():
|
if self.has_uninitialized_params():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.in_features = input.shape[-1]
|
self.in_features = input.shape[-1]
|
||||||
|
@ -38,11 +38,13 @@ T = TypeVar("T", bound="Module")
|
|||||||
|
|
||||||
|
|
||||||
class _IncompatibleKeys(
|
class _IncompatibleKeys(
|
||||||
|
# pyrefly: ignore # invalid-inheritance
|
||||||
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
|
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
|
||||||
):
|
):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if not self.missing_keys and not self.unexpected_keys:
|
if not self.missing_keys and not self.unexpected_keys:
|
||||||
return "<All keys matched successfully>"
|
return "<All keys matched successfully>"
|
||||||
return super().__repr__()
|
return super().__repr__()
|
||||||
@ -91,6 +93,7 @@ class _WrappedHook:
|
|||||||
def __getstate__(self) -> dict:
|
def __getstate__(self) -> dict:
|
||||||
result = {"hook": self.hook, "with_module": self.with_module}
|
result = {"hook": self.hook, "with_module": self.with_module}
|
||||||
if self.with_module:
|
if self.with_module:
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
result["module"] = self.module()
|
result["module"] = self.module()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@ -976,7 +979,9 @@ class Module:
|
|||||||
# Decrement use count of the gradient by setting to None
|
# Decrement use count of the gradient by setting to None
|
||||||
param.grad = None
|
param.grad = None
|
||||||
param_applied = torch.nn.Parameter(
|
param_applied = torch.nn.Parameter(
|
||||||
param_applied, requires_grad=param.requires_grad
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
param_applied,
|
||||||
|
requires_grad=param.requires_grad,
|
||||||
)
|
)
|
||||||
torch.utils.swap_tensors(param, param_applied)
|
torch.utils.swap_tensors(param, param_applied)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -987,11 +992,13 @@ class Module:
|
|||||||
) from e
|
) from e
|
||||||
out_param = param
|
out_param = param
|
||||||
elif p_should_use_set_data:
|
elif p_should_use_set_data:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
param.data = param_applied
|
param.data = param_applied
|
||||||
out_param = param
|
out_param = param
|
||||||
else:
|
else:
|
||||||
assert isinstance(param, Parameter)
|
assert isinstance(param, Parameter)
|
||||||
assert param.is_leaf
|
assert param.is_leaf
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
out_param = Parameter(param_applied, param.requires_grad)
|
out_param = Parameter(param_applied, param.requires_grad)
|
||||||
self._parameters[key] = out_param
|
self._parameters[key] = out_param
|
||||||
|
|
||||||
@ -2253,6 +2260,7 @@ class Module:
|
|||||||
|
|
||||||
if destination is None:
|
if destination is None:
|
||||||
destination = OrderedDict()
|
destination = OrderedDict()
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
destination._metadata = OrderedDict()
|
destination._metadata = OrderedDict()
|
||||||
|
|
||||||
local_metadata = dict(version=self._version)
|
local_metadata = dict(version=self._version)
|
||||||
@ -2402,7 +2410,9 @@ class Module:
|
|||||||
if k not in self._non_persistent_buffers_set
|
if k not in self._non_persistent_buffers_set
|
||||||
}
|
}
|
||||||
local_name_params = itertools.chain(
|
local_name_params = itertools.chain(
|
||||||
self._parameters.items(), persistent_buffers.items()
|
self._parameters.items(),
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
persistent_buffers.items(),
|
||||||
)
|
)
|
||||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||||
|
@ -84,6 +84,7 @@ class CircularPad1d(_CircularPadNd):
|
|||||||
[5., 6., 7., 4., 5., 6., 7., 4.]]])
|
[5., 6., 7., 4., 5., 6., 7., 4.]]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int]
|
padding: tuple[int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_2_t) -> None:
|
def __init__(self, padding: _size_2_t) -> None:
|
||||||
@ -144,6 +145,7 @@ class CircularPad2d(_CircularPadNd):
|
|||||||
[8., 6., 7., 8., 6.]]]])
|
[8., 6., 7., 8., 6.]]]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int, int, int]
|
padding: tuple[int, int, int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_4_t) -> None:
|
def __init__(self, padding: _size_4_t) -> None:
|
||||||
@ -194,6 +196,7 @@ class CircularPad3d(_CircularPadNd):
|
|||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int, int, int, int, int]
|
padding: tuple[int, int, int, int, int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_6_t) -> None:
|
def __init__(self, padding: _size_6_t) -> None:
|
||||||
@ -265,6 +268,7 @@ class ConstantPad1d(_ConstantPadNd):
|
|||||||
[ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
|
[ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int]
|
padding: tuple[int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_2_t, value: float) -> None:
|
def __init__(self, padding: _size_2_t, value: float) -> None:
|
||||||
@ -316,6 +320,7 @@ class ConstantPad2d(_ConstantPadNd):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
__constants__ = ["padding", "value"]
|
__constants__ = ["padding", "value"]
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int, int, int]
|
padding: tuple[int, int, int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_4_t, value: float) -> None:
|
def __init__(self, padding: _size_4_t, value: float) -> None:
|
||||||
@ -356,6 +361,7 @@ class ConstantPad3d(_ConstantPadNd):
|
|||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int, int, int, int, int]
|
padding: tuple[int, int, int, int, int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_6_t, value: float) -> None:
|
def __init__(self, padding: _size_6_t, value: float) -> None:
|
||||||
@ -409,6 +415,7 @@ class ReflectionPad1d(_ReflectionPadNd):
|
|||||||
[7., 6., 5., 4., 5., 6., 7., 6.]]])
|
[7., 6., 5., 4., 5., 6., 7., 6.]]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int]
|
padding: tuple[int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_2_t) -> None:
|
def __init__(self, padding: _size_2_t) -> None:
|
||||||
@ -462,6 +469,7 @@ class ReflectionPad2d(_ReflectionPadNd):
|
|||||||
[7., 6., 7., 8., 7.]]]])
|
[7., 6., 7., 8., 7.]]]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int, int, int]
|
padding: tuple[int, int, int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_4_t) -> None:
|
def __init__(self, padding: _size_4_t) -> None:
|
||||||
@ -517,6 +525,7 @@ class ReflectionPad3d(_ReflectionPadNd):
|
|||||||
[1., 0., 1., 0.]]]]])
|
[1., 0., 1., 0.]]]]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int, int, int, int, int]
|
padding: tuple[int, int, int, int, int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_6_t) -> None:
|
def __init__(self, padding: _size_6_t) -> None:
|
||||||
@ -570,6 +579,7 @@ class ReplicationPad1d(_ReplicationPadNd):
|
|||||||
[4., 4., 4., 4., 5., 6., 7., 7.]]])
|
[4., 4., 4., 4., 5., 6., 7., 7.]]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int]
|
padding: tuple[int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_2_t) -> None:
|
def __init__(self, padding: _size_2_t) -> None:
|
||||||
@ -623,6 +633,7 @@ class ReplicationPad2d(_ReplicationPadNd):
|
|||||||
[6., 6., 7., 8., 8.]]]])
|
[6., 6., 7., 8., 8.]]]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int, int, int]
|
padding: tuple[int, int, int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_4_t) -> None:
|
def __init__(self, padding: _size_4_t) -> None:
|
||||||
@ -665,6 +676,7 @@ class ReplicationPad3d(_ReplicationPadNd):
|
|||||||
>>> output = m(input)
|
>>> output = m(input)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
padding: tuple[int, int, int, int, int, int]
|
padding: tuple[int, int, int, int, int, int]
|
||||||
|
|
||||||
def __init__(self, padding: _size_6_t) -> None:
|
def __init__(self, padding: _size_6_t) -> None:
|
||||||
|
@ -111,6 +111,7 @@ class RNNBase(Module):
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
not isinstance(dropout, numbers.Number)
|
not isinstance(dropout, numbers.Number)
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
or not 0 <= dropout <= 1
|
or not 0 <= dropout <= 1
|
||||||
or isinstance(dropout, bool)
|
or isinstance(dropout, bool)
|
||||||
):
|
):
|
||||||
@ -119,6 +120,7 @@ class RNNBase(Module):
|
|||||||
"representing the probability of an element being "
|
"representing the probability of an element being "
|
||||||
"zeroed"
|
"zeroed"
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
if dropout > 0 and num_layers == 1:
|
if dropout > 0 and num_layers == 1:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"dropout option adds dropout after all but last "
|
"dropout option adds dropout after all but last "
|
||||||
@ -639,15 +641,22 @@ class RNN(RNNBase):
|
|||||||
|
|
||||||
@overload
|
@overload
|
||||||
@torch._jit_internal._overload_method # noqa: F811
|
@torch._jit_internal._overload_method # noqa: F811
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(
|
def forward(
|
||||||
self, input: Tensor, hx: Optional[Tensor] = None
|
self,
|
||||||
|
input: Tensor,
|
||||||
|
hx: Optional[Tensor] = None,
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
) -> tuple[Tensor, Tensor]:
|
) -> tuple[Tensor, Tensor]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@torch._jit_internal._overload_method # noqa: F811
|
@torch._jit_internal._overload_method # noqa: F811
|
||||||
def forward(
|
def forward(
|
||||||
self, input: PackedSequence, hx: Optional[Tensor] = None
|
self,
|
||||||
|
input: PackedSequence,
|
||||||
|
hx: Optional[Tensor] = None,
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
) -> tuple[PackedSequence, Tensor]:
|
) -> tuple[PackedSequence, Tensor]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -772,7 +781,11 @@ class RNN(RNNBase):
|
|||||||
|
|
||||||
if isinstance(orig_input, PackedSequence):
|
if isinstance(orig_input, PackedSequence):
|
||||||
output_packed = PackedSequence(
|
output_packed = PackedSequence(
|
||||||
output, batch_sizes, sorted_indices, unsorted_indices
|
output,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
batch_sizes,
|
||||||
|
sorted_indices,
|
||||||
|
unsorted_indices,
|
||||||
)
|
)
|
||||||
return output_packed, self.permute_hidden(hidden, unsorted_indices)
|
return output_packed, self.permute_hidden(hidden, unsorted_indices)
|
||||||
|
|
||||||
@ -996,6 +1009,7 @@ class LSTM(RNNBase):
|
|||||||
|
|
||||||
# In the future, we should prevent mypy from applying contravariance rules here.
|
# In the future, we should prevent mypy from applying contravariance rules here.
|
||||||
# See torch/nn/modules/module.py::_forward_unimplemented
|
# See torch/nn/modules/module.py::_forward_unimplemented
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def check_forward_args(
|
def check_forward_args(
|
||||||
self,
|
self,
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
@ -1029,8 +1043,12 @@ class LSTM(RNNBase):
|
|||||||
# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
|
# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
|
||||||
@overload # type: ignore[override]
|
@overload # type: ignore[override]
|
||||||
@torch._jit_internal._overload_method # noqa: F811
|
@torch._jit_internal._overload_method # noqa: F811
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(
|
def forward(
|
||||||
self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
|
self,
|
||||||
|
input: Tensor,
|
||||||
|
hx: Optional[tuple[Tensor, Tensor]] = None,
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811
|
) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -1038,7 +1056,10 @@ class LSTM(RNNBase):
|
|||||||
@overload
|
@overload
|
||||||
@torch._jit_internal._overload_method # noqa: F811
|
@torch._jit_internal._overload_method # noqa: F811
|
||||||
def forward(
|
def forward(
|
||||||
self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None
|
self,
|
||||||
|
input: PackedSequence,
|
||||||
|
hx: Optional[tuple[Tensor, Tensor]] = None,
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811
|
) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -1152,7 +1173,11 @@ class LSTM(RNNBase):
|
|||||||
# xxx: isinstance check needs to be in conditional for TorchScript to compile
|
# xxx: isinstance check needs to be in conditional for TorchScript to compile
|
||||||
if isinstance(orig_input, PackedSequence):
|
if isinstance(orig_input, PackedSequence):
|
||||||
output_packed = PackedSequence(
|
output_packed = PackedSequence(
|
||||||
output, batch_sizes, sorted_indices, unsorted_indices
|
output,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
batch_sizes,
|
||||||
|
sorted_indices,
|
||||||
|
unsorted_indices,
|
||||||
)
|
)
|
||||||
return output_packed, self.permute_hidden(hidden, unsorted_indices)
|
return output_packed, self.permute_hidden(hidden, unsorted_indices)
|
||||||
else:
|
else:
|
||||||
@ -1318,15 +1343,22 @@ class GRU(RNNBase):
|
|||||||
|
|
||||||
@overload # type: ignore[override]
|
@overload # type: ignore[override]
|
||||||
@torch._jit_internal._overload_method # noqa: F811
|
@torch._jit_internal._overload_method # noqa: F811
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(
|
def forward(
|
||||||
self, input: Tensor, hx: Optional[Tensor] = None
|
self,
|
||||||
|
input: Tensor,
|
||||||
|
hx: Optional[Tensor] = None,
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
) -> tuple[Tensor, Tensor]: # noqa: F811
|
) -> tuple[Tensor, Tensor]: # noqa: F811
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@torch._jit_internal._overload_method # noqa: F811
|
@torch._jit_internal._overload_method # noqa: F811
|
||||||
def forward(
|
def forward(
|
||||||
self, input: PackedSequence, hx: Optional[Tensor] = None
|
self,
|
||||||
|
input: PackedSequence,
|
||||||
|
hx: Optional[Tensor] = None,
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
) -> tuple[PackedSequence, Tensor]: # noqa: F811
|
) -> tuple[PackedSequence, Tensor]: # noqa: F811
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -1420,7 +1452,11 @@ class GRU(RNNBase):
|
|||||||
# xxx: isinstance check needs to be in conditional for TorchScript to compile
|
# xxx: isinstance check needs to be in conditional for TorchScript to compile
|
||||||
if isinstance(orig_input, PackedSequence):
|
if isinstance(orig_input, PackedSequence):
|
||||||
output_packed = PackedSequence(
|
output_packed = PackedSequence(
|
||||||
output, batch_sizes, sorted_indices, unsorted_indices
|
output,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
batch_sizes,
|
||||||
|
sorted_indices,
|
||||||
|
unsorted_indices,
|
||||||
)
|
)
|
||||||
return output_packed, self.permute_hidden(hidden, unsorted_indices)
|
return output_packed, self.permute_hidden(hidden, unsorted_indices)
|
||||||
else:
|
else:
|
||||||
|
@ -135,7 +135,11 @@ class Transformer(Module):
|
|||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
encoder_norm = LayerNorm(
|
encoder_norm = LayerNorm(
|
||||||
d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
|
d_model,
|
||||||
|
eps=layer_norm_eps,
|
||||||
|
bias=bias,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
self.encoder = TransformerEncoder(
|
self.encoder = TransformerEncoder(
|
||||||
encoder_layer, num_encoder_layers, encoder_norm
|
encoder_layer, num_encoder_layers, encoder_norm
|
||||||
@ -157,7 +161,11 @@ class Transformer(Module):
|
|||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
decoder_norm = LayerNorm(
|
decoder_norm = LayerNorm(
|
||||||
d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
|
d_model,
|
||||||
|
eps=layer_norm_eps,
|
||||||
|
bias=bias,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
self.decoder = TransformerDecoder(
|
self.decoder = TransformerDecoder(
|
||||||
decoder_layer, num_decoder_layers, decoder_norm
|
decoder_layer, num_decoder_layers, decoder_norm
|
||||||
@ -760,7 +768,9 @@ class TransformerEncoderLayer(Module):
|
|||||||
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
||||||
|
|
||||||
self.norm_first = norm_first
|
self.norm_first = norm_first
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
||||||
self.dropout1 = Dropout(dropout)
|
self.dropout1 = Dropout(dropout)
|
||||||
self.dropout2 = Dropout(dropout)
|
self.dropout2 = Dropout(dropout)
|
||||||
@ -1052,8 +1062,11 @@ class TransformerDecoderLayer(Module):
|
|||||||
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
||||||
|
|
||||||
self.norm_first = norm_first
|
self.norm_first = norm_first
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
||||||
self.dropout1 = Dropout(dropout)
|
self.dropout1 = Dropout(dropout)
|
||||||
self.dropout2 = Dropout(dropout)
|
self.dropout2 = Dropout(dropout)
|
||||||
|
@ -36,6 +36,7 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
if isinstance(out_size, (int, torch.SymInt)):
|
if isinstance(out_size, (int, torch.SymInt)):
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return out_size
|
return out_size
|
||||||
if len(defaults) <= len(out_size):
|
if len(defaults) <= len(out_size):
|
||||||
raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")
|
raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")
|
||||||
|
@ -43,6 +43,7 @@ def broadcast(tensor, devices=None, *, out=None):
|
|||||||
devices = [_get_device_index(d) for d in devices]
|
devices = [_get_device_index(d) for d in devices]
|
||||||
return torch._C._broadcast(tensor, devices)
|
return torch._C._broadcast(tensor, devices)
|
||||||
else:
|
else:
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
return torch._C._broadcast_out(tensor, out)
|
return torch._C._broadcast_out(tensor, out)
|
||||||
|
|
||||||
|
|
||||||
@ -200,6 +201,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=
|
|||||||
"""
|
"""
|
||||||
tensor = _handle_complex(tensor)
|
tensor = _handle_complex(tensor)
|
||||||
if out is None:
|
if out is None:
|
||||||
|
# pyrefly: ignore # not-iterable
|
||||||
devices = [_get_device_index(d) for d in devices]
|
devices = [_get_device_index(d) for d in devices]
|
||||||
return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
|
return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
|
||||||
else:
|
else:
|
||||||
|
@ -160,6 +160,7 @@ class DataParallel(Module, Generic[T]):
|
|||||||
self.module = module
|
self.module = module
|
||||||
self.device_ids = [_get_device_index(x, True) for x in device_ids]
|
self.device_ids = [_get_device_index(x, True) for x in device_ids]
|
||||||
self.output_device = _get_device_index(output_device, True)
|
self.output_device = _get_device_index(output_device, True)
|
||||||
|
# pyrefly: ignore # read-only
|
||||||
self.src_device_obj = torch.device(device_type, self.device_ids[0])
|
self.src_device_obj = torch.device(device_type, self.device_ids[0])
|
||||||
|
|
||||||
if device_type == "cuda":
|
if device_type == "cuda":
|
||||||
@ -173,6 +174,7 @@ class DataParallel(Module, Generic[T]):
|
|||||||
if not self.device_ids:
|
if not self.device_ids:
|
||||||
return self.module(*inputs, **kwargs)
|
return self.module(*inputs, **kwargs)
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
for t in chain(self.module.parameters(), self.module.buffers()):
|
for t in chain(self.module.parameters(), self.module.buffers()):
|
||||||
if t.device != self.src_device_obj:
|
if t.device != self.src_device_obj:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -259,8 +261,10 @@ def data_parallel(
|
|||||||
|
|
||||||
device_ids = [_get_device_index(x, True) for x in device_ids]
|
device_ids = [_get_device_index(x, True) for x in device_ids]
|
||||||
output_device = _get_device_index(output_device, True)
|
output_device = _get_device_index(output_device, True)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
src_device_obj = torch.device(device_type, device_ids[0])
|
src_device_obj = torch.device(device_type, device_ids[0])
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
for t in chain(module.parameters(), module.buffers()):
|
for t in chain(module.parameters(), module.buffers()):
|
||||||
if t.device != src_device_obj:
|
if t.device != src_device_obj:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -241,6 +241,7 @@ class _BufferCommHook:
|
|||||||
# is completed.
|
# is completed.
|
||||||
class _DDPSink(Function):
|
class _DDPSink(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, ddp_weakref, *inputs):
|
def forward(ctx, ddp_weakref, *inputs):
|
||||||
# set_materialize_grads(False) will ensure that None gradients stay as
|
# set_materialize_grads(False) will ensure that None gradients stay as
|
||||||
# None and are not filled with zeros.
|
# None and are not filled with zeros.
|
||||||
@ -691,6 +692,7 @@ class DistributedDataParallel(Module, Joinable):
|
|||||||
elif process_group is None and device_mesh is None:
|
elif process_group is None and device_mesh is None:
|
||||||
self.process_group = _get_default_group()
|
self.process_group = _get_default_group()
|
||||||
elif device_mesh is None:
|
elif device_mesh is None:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
else:
|
else:
|
||||||
if device_mesh.ndim != 1:
|
if device_mesh.ndim != 1:
|
||||||
@ -779,11 +781,13 @@ class DistributedDataParallel(Module, Joinable):
|
|||||||
self.device_ids = None
|
self.device_ids = None
|
||||||
self.output_device = None
|
self.output_device = None
|
||||||
else:
|
else:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
self.device_ids = [_get_device_index(x, True) for x in device_ids]
|
self.device_ids = [_get_device_index(x, True) for x in device_ids]
|
||||||
|
|
||||||
if output_device is None:
|
if output_device is None:
|
||||||
output_device = device_ids[0]
|
output_device = device_ids[0]
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
self.output_device = _get_device_index(output_device, True)
|
self.output_device = _get_device_index(output_device, True)
|
||||||
|
|
||||||
self.static_graph = False
|
self.static_graph = False
|
||||||
@ -933,6 +937,7 @@ class DistributedDataParallel(Module, Joinable):
|
|||||||
# enabled.
|
# enabled.
|
||||||
self._accum_grad_hooks: list[RemovableHandle] = []
|
self._accum_grad_hooks: list[RemovableHandle] = []
|
||||||
if self._use_python_reducer:
|
if self._use_python_reducer:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
torch._inductor.config._fuse_ddp_communication = True
|
torch._inductor.config._fuse_ddp_communication = True
|
||||||
torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
|
torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
|
||||||
# Directly adding this to the trace rule will disturb the users
|
# Directly adding this to the trace rule will disturb the users
|
||||||
|
@ -56,12 +56,16 @@ def scatter(inputs, target_gpus, dim=0):
|
|||||||
if isinstance(obj, torch.Tensor):
|
if isinstance(obj, torch.Tensor):
|
||||||
return Scatter.apply(target_gpus, None, dim, obj)
|
return Scatter.apply(target_gpus, None, dim, obj)
|
||||||
if _is_namedtuple(obj):
|
if _is_namedtuple(obj):
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
|
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
|
||||||
if isinstance(obj, tuple) and len(obj) > 0:
|
if isinstance(obj, tuple) and len(obj) > 0:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return list(zip(*map(scatter_map, obj)))
|
return list(zip(*map(scatter_map, obj)))
|
||||||
if isinstance(obj, list) and len(obj) > 0:
|
if isinstance(obj, list) and len(obj) > 0:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return [list(i) for i in zip(*map(scatter_map, obj))]
|
return [list(i) for i in zip(*map(scatter_map, obj))]
|
||||||
if isinstance(obj, dict) and len(obj) > 0:
|
if isinstance(obj, dict) and len(obj) > 0:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
|
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
|
||||||
return [obj for _ in target_gpus]
|
return [obj for _ in target_gpus]
|
||||||
|
|
||||||
@ -123,9 +127,12 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0)
|
|||||||
if isinstance(out, dict):
|
if isinstance(out, dict):
|
||||||
if not all(len(out) == len(d) for d in outputs):
|
if not all(len(out) == len(d) for d in outputs):
|
||||||
raise ValueError("All dicts must have the same number of keys")
|
raise ValueError("All dicts must have the same number of keys")
|
||||||
|
# pyrefly: ignore # not-callable
|
||||||
return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
|
return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
|
||||||
if _is_namedtuple(out):
|
if _is_namedtuple(out):
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return type(out)._make(map(gather_map, zip(*outputs)))
|
return type(out)._make(map(gather_map, zip(*outputs)))
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
return type(out)(map(gather_map, zip(*outputs)))
|
return type(out)(map(gather_map, zip(*outputs)))
|
||||||
|
|
||||||
# Recursive function calls like this create reference cycles.
|
# Recursive function calls like this create reference cycles.
|
||||||
|
@ -81,6 +81,7 @@ class Parameter(torch.Tensor, metaclass=_ParameterMeta):
|
|||||||
memo[id(self)] = result
|
memo[id(self)] = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "Parameter containing:\n" + super().__repr__()
|
return "Parameter containing:\n" + super().__repr__()
|
||||||
|
|
||||||
@ -143,6 +144,7 @@ class UninitializedTensorMixin:
|
|||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = self.data.dtype
|
dtype = self.data.dtype
|
||||||
self.data = torch.empty(shape, device=device, dtype=dtype)
|
self.data = torch.empty(shape, device=device, dtype=dtype)
|
||||||
|
# pyrefly: ignore # bad-override, missing-attribute
|
||||||
self.__class__ = self.cls_to_become
|
self.__class__ = self.cls_to_become
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -166,6 +168,7 @@ class UninitializedTensorMixin:
|
|||||||
|
|
||||||
def __reduce_ex__(self, proto):
|
def __reduce_ex__(self, proto):
|
||||||
# See Note [Don't serialize hooks]
|
# See Note [Don't serialize hooks]
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return (self.__class__, (self.requires_grad,))
|
return (self.__class__, (self.requires_grad,))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -175,6 +178,7 @@ class UninitializedTensorMixin:
|
|||||||
if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
|
if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
|
||||||
if kwargs is None:
|
if kwargs is None:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return super().__torch_function__(func, types, args, kwargs)
|
return super().__torch_function__(func, types, args, kwargs)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Attempted to use an uninitialized parameter in {func}. "
|
f"Attempted to use an uninitialized parameter in {func}. "
|
||||||
@ -216,6 +220,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
|
|||||||
def __new__(cls, requires_grad=True, device=None, dtype=None) -> None:
|
def __new__(cls, requires_grad=True, device=None, dtype=None) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
data = torch.empty(0, **factory_kwargs)
|
data = torch.empty(0, **factory_kwargs)
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return torch.Tensor._make_subclass(cls, data, requires_grad)
|
return torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
@ -261,7 +266,9 @@ class Buffer(torch.Tensor, metaclass=_BufferMeta):
|
|||||||
data = torch.empty(0)
|
data = torch.empty(0)
|
||||||
|
|
||||||
t = data.detach().requires_grad_(data.requires_grad)
|
t = data.detach().requires_grad_(data.requires_grad)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t.persistent = persistent
|
t.persistent = persistent
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
t._is_buffer = True
|
t._is_buffer = True
|
||||||
return t
|
return t
|
||||||
|
|
||||||
@ -292,6 +299,9 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
|
|||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
data = torch.empty(0, **factory_kwargs)
|
data = torch.empty(0, **factory_kwargs)
|
||||||
ret = torch.Tensor._make_subclass(cls, data, requires_grad)
|
ret = torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
ret.persistent = persistent
|
ret.persistent = persistent
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
ret._is_buffer = True
|
ret._is_buffer = True
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return ret
|
return ret
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from . import parametrizations, parametrize, rnn, stateless
|
from . import parametrizations, parametrize, rnn, stateless
|
||||||
from .clip_grad import (
|
from .clip_grad import ( # pyrefly: ignore # deprecated
|
||||||
_clip_grads_with_norm_ as clip_grads_with_norm_,
|
_clip_grads_with_norm_ as clip_grads_with_norm_,
|
||||||
_get_total_norm as get_total_norm,
|
_get_total_norm as get_total_norm,
|
||||||
clip_grad_norm,
|
clip_grad_norm,
|
||||||
|
@ -24,6 +24,7 @@ from .expanded_weights_utils import forward_helper
|
|||||||
@implements_per_sample_grads(F.conv3d)
|
@implements_per_sample_grads(F.conv3d)
|
||||||
class ConvPerSampleGrad(torch.autograd.Function):
|
class ConvPerSampleGrad(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(
|
def forward(
|
||||||
ctx: Any,
|
ctx: Any,
|
||||||
kwarg_names: list[str],
|
kwarg_names: list[str],
|
||||||
@ -56,6 +57,7 @@ class ConvPerSampleGrad(torch.autograd.Function):
|
|||||||
f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}"
|
f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore # invalid-type-var
|
||||||
ctx.conv_fn = conv_fn
|
ctx.conv_fn = conv_fn
|
||||||
|
|
||||||
ctx.batch_size = orig_input.shape[0]
|
ctx.batch_size = orig_input.shape[0]
|
||||||
|
@ -237,6 +237,7 @@ def conv_unfold_weight_grad_sample(
|
|||||||
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
|
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
|
||||||
weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
|
weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
|
||||||
# rearrange the above tensor and extract diagonals.
|
# rearrange the above tensor and extract diagonals.
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
weight_grad_sample = weight_grad_sample.view(
|
weight_grad_sample = weight_grad_sample.view(
|
||||||
n,
|
n,
|
||||||
groups,
|
groups,
|
||||||
|
@ -14,6 +14,7 @@ from .expanded_weights_utils import (
|
|||||||
@implements_per_sample_grads(F.embedding)
|
@implements_per_sample_grads(F.embedding)
|
||||||
class EmbeddingPerSampleGrad(torch.autograd.Function):
|
class EmbeddingPerSampleGrad(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(
|
def forward(
|
||||||
ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
|
ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -34,6 +35,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(
|
def backward(
|
||||||
ctx: Any, grad_output: torch.Tensor
|
ctx: Any, grad_output: torch.Tensor
|
||||||
) -> tuple[Optional[torch.Tensor], ...]:
|
) -> tuple[Optional[torch.Tensor], ...]:
|
||||||
|
@ -131,7 +131,9 @@ class ExpandedWeight(torch.Tensor):
|
|||||||
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
|
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
|
||||||
decomp_opts = expanded_weights_rnn_decomps[func]
|
decomp_opts = expanded_weights_rnn_decomps[func]
|
||||||
use_input_variant = isinstance(
|
use_input_variant = isinstance(
|
||||||
args[2], list
|
# pyrefly: ignore # index-error
|
||||||
|
args[2],
|
||||||
|
list,
|
||||||
) # data variant uses a list here
|
) # data variant uses a list here
|
||||||
decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]
|
decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ from .expanded_weights_impl import ExpandedWeight
|
|||||||
|
|
||||||
def is_batch_first(expanded_args_and_kwargs):
|
def is_batch_first(expanded_args_and_kwargs):
|
||||||
batch_first = None
|
batch_first = None
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
for arg in expanded_args_and_kwargs:
|
for arg in expanded_args_and_kwargs:
|
||||||
if not isinstance(arg, ExpandedWeight):
|
if not isinstance(arg, ExpandedWeight):
|
||||||
continue
|
continue
|
||||||
|
@ -18,6 +18,7 @@ from .expanded_weights_utils import (
|
|||||||
@implements_per_sample_grads(F.group_norm)
|
@implements_per_sample_grads(F.group_norm)
|
||||||
class GroupNormPerSampleGrad(torch.autograd.Function):
|
class GroupNormPerSampleGrad(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
|
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
|
||||||
expanded_args, expanded_kwargs = standard_kwargs(
|
expanded_args, expanded_kwargs = standard_kwargs(
|
||||||
kwarg_names, expanded_args_and_kwargs
|
kwarg_names, expanded_args_and_kwargs
|
||||||
@ -46,6 +47,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input, num_groups = ctx.input, ctx.num_groups
|
input, num_groups = ctx.input, ctx.num_groups
|
||||||
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
|
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
|
||||||
@ -94,7 +96,9 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
|
|||||||
set_grad_sample_if_exists(
|
set_grad_sample_if_exists(
|
||||||
weight,
|
weight,
|
||||||
lambda _: torch.einsum(
|
lambda _: torch.einsum(
|
||||||
"ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output
|
"ni...->ni",
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
|
F.group_norm(input, num_groups, eps=eps) * grad_output,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if hasattr(ctx, "bias"):
|
if hasattr(ctx, "bias"):
|
||||||
|
@ -17,6 +17,7 @@ from .expanded_weights_utils import (
|
|||||||
@implements_per_sample_grads(F.instance_norm)
|
@implements_per_sample_grads(F.instance_norm)
|
||||||
class InstanceNormPerSampleGrad(torch.autograd.Function):
|
class InstanceNormPerSampleGrad(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
|
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
|
||||||
instance_norm = partial(torch.instance_norm, cudnn_enabled=True)
|
instance_norm = partial(torch.instance_norm, cudnn_enabled=True)
|
||||||
expanded_args, expanded_kwargs = standard_kwargs(
|
expanded_args, expanded_kwargs = standard_kwargs(
|
||||||
@ -36,6 +37,7 @@ class InstanceNormPerSampleGrad(torch.autograd.Function):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
|
input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
|
||||||
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
|
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
|
||||||
|
@ -17,6 +17,7 @@ from .expanded_weights_utils import (
|
|||||||
@implements_per_sample_grads(F.layer_norm)
|
@implements_per_sample_grads(F.layer_norm)
|
||||||
class LayerNormPerSampleGrad(torch.autograd.Function):
|
class LayerNormPerSampleGrad(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
|
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
|
||||||
expanded_args, expanded_kwargs = standard_kwargs(
|
expanded_args, expanded_kwargs = standard_kwargs(
|
||||||
kwarg_names, expanded_args_and_kwargs
|
kwarg_names, expanded_args_and_kwargs
|
||||||
@ -42,6 +43,7 @@ class LayerNormPerSampleGrad(torch.autograd.Function):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
def weight_per_sample_grad(weight):
|
def weight_per_sample_grad(weight):
|
||||||
return sum_over_all_but_batch_and_last_n(
|
return sum_over_all_but_batch_and_last_n(
|
||||||
|
@ -16,6 +16,7 @@ from .expanded_weights_utils import (
|
|||||||
@implements_per_sample_grads(F.linear)
|
@implements_per_sample_grads(F.linear)
|
||||||
class LinearPerSampleGrad(torch.autograd.Function):
|
class LinearPerSampleGrad(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, _, __, *expanded_args_and_kwargs):
|
def forward(ctx, _, __, *expanded_args_and_kwargs):
|
||||||
if len(expanded_args_and_kwargs[0].shape) <= 1:
|
if len(expanded_args_and_kwargs[0].shape) <= 1:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -35,6 +36,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input, weight = ctx.args
|
input, weight = ctx.args
|
||||||
bias = ctx.kwargs["bias"]
|
bias = ctx.kwargs["bias"]
|
||||||
|
@ -77,6 +77,7 @@ def swap_tensor(
|
|||||||
setattr(module, name, tensor)
|
setattr(module, name, tensor)
|
||||||
elif hasattr(module, name):
|
elif hasattr(module, name):
|
||||||
delattr(module, name)
|
delattr(module, name)
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return orig_tensor
|
return orig_tensor
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,9 +41,11 @@ def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
|
|||||||
|
|
||||||
def _no_grad_wrapper(*args, **kwargs):
|
def _no_grad_wrapper(*args, **kwargs):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
# pyrefly: ignore # invalid-param-spec
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
functools.update_wrapper(_no_grad_wrapper, func)
|
functools.update_wrapper(_no_grad_wrapper, func)
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return _no_grad_wrapper
|
return _no_grad_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@ -84,6 +84,7 @@ def convert_conv2d_weight_memory_format(
|
|||||||
)
|
)
|
||||||
for child in module.children():
|
for child in module.children():
|
||||||
convert_conv2d_weight_memory_format(child, memory_format)
|
convert_conv2d_weight_memory_format(child, memory_format)
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
@ -163,6 +164,7 @@ def convert_conv3d_weight_memory_format(
|
|||||||
)
|
)
|
||||||
for child in module.children():
|
for child in module.children():
|
||||||
convert_conv3d_weight_memory_format(child, memory_format)
|
convert_conv3d_weight_memory_format(child, memory_format)
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
@ -98,6 +98,7 @@ class _Orthogonal(Module):
|
|||||||
)
|
)
|
||||||
# Q is now orthogonal (or unitary) of size (..., n, n)
|
# Q is now orthogonal (or unitary) of size (..., n, n)
|
||||||
if n != k:
|
if n != k:
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
Q = Q[..., :k]
|
Q = Q[..., :k]
|
||||||
# Q is now the size of the X (albeit perhaps transposed)
|
# Q is now the size of the X (albeit perhaps transposed)
|
||||||
else:
|
else:
|
||||||
|
@ -179,23 +179,28 @@ class ParametrizationList(ModuleList):
|
|||||||
|
|
||||||
# Register the tensor(s)
|
# Register the tensor(s)
|
||||||
if self.is_tensor:
|
if self.is_tensor:
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if original.dtype != new.dtype:
|
if original.dtype != new.dtype:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"When `right_inverse` outputs one tensor, it may not change the dtype.\n"
|
"When `right_inverse` outputs one tensor, it may not change the dtype.\n"
|
||||||
f"original.dtype: {original.dtype}\n"
|
f"original.dtype: {original.dtype}\n"
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
f"right_inverse(original).dtype: {new.dtype}"
|
f"right_inverse(original).dtype: {new.dtype}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if original.device != new.device:
|
if original.device != new.device:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"When `right_inverse` outputs one tensor, it may not change the device.\n"
|
"When `right_inverse` outputs one tensor, it may not change the device.\n"
|
||||||
f"original.device: {original.device}\n"
|
f"original.device: {original.device}\n"
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
f"right_inverse(original).device: {new.device}"
|
f"right_inverse(original).device: {new.device}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set the original to original so that the user does not need to re-register the parameter
|
# Set the original to original so that the user does not need to re-register the parameter
|
||||||
# manually in the optimiser
|
# manually in the optimiser
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
_maybe_set(original, new)
|
_maybe_set(original, new)
|
||||||
_register_parameter_or_buffer(self, "original", original)
|
_register_parameter_or_buffer(self, "original", original)
|
||||||
else:
|
else:
|
||||||
@ -396,6 +401,7 @@ def _inject_property(module: Module, tensor_name: str) -> None:
|
|||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting():
|
||||||
raise RuntimeError("Parametrization is not working with scripting.")
|
raise RuntimeError("Parametrization is not working with scripting.")
|
||||||
parametrization = self.parametrizations[tensor_name]
|
parametrization = self.parametrizations[tensor_name]
|
||||||
|
# pyrefly: ignore # redundant-condition
|
||||||
if _cache_enabled:
|
if _cache_enabled:
|
||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting():
|
||||||
# Scripting
|
# Scripting
|
||||||
@ -695,6 +701,7 @@ def remove_parametrizations(
|
|||||||
# Fetch the original tensor
|
# Fetch the original tensor
|
||||||
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
|
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
|
||||||
parametrizations = module.parametrizations[tensor_name]
|
parametrizations = module.parametrizations[tensor_name]
|
||||||
|
# pyrefly: ignore # invalid-argument
|
||||||
if parametrizations.is_tensor:
|
if parametrizations.is_tensor:
|
||||||
original = parametrizations.original
|
original = parametrizations.original
|
||||||
assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"
|
assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"
|
||||||
|
@ -274,8 +274,11 @@ class PruningContainer(BasePruningMethod):
|
|||||||
if not isinstance(args, Iterable): # only 1 item
|
if not isinstance(args, Iterable): # only 1 item
|
||||||
self._tensor_name = args._tensor_name
|
self._tensor_name = args._tensor_name
|
||||||
self.add_pruning_method(args)
|
self.add_pruning_method(args)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
elif len(args) == 1: # only 1 item in a tuple
|
elif len(args) == 1: # only 1 item in a tuple
|
||||||
|
# pyrefly: ignore # index-error
|
||||||
self._tensor_name = args[0]._tensor_name
|
self._tensor_name = args[0]._tensor_name
|
||||||
|
# pyrefly: ignore # index-error
|
||||||
self.add_pruning_method(args[0])
|
self.add_pruning_method(args[0])
|
||||||
else: # manual construction from list or other iterable (or no args)
|
else: # manual construction from list or other iterable (or no args)
|
||||||
for method in args:
|
for method in args:
|
||||||
@ -1097,6 +1100,7 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw
|
|||||||
|
|
||||||
# flatten importance scores to consider them all at once in global pruning
|
# flatten importance scores to consider them all at once in global pruning
|
||||||
relevant_importance_scores = torch.nn.utils.parameters_to_vector(
|
relevant_importance_scores = torch.nn.utils.parameters_to_vector(
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
[
|
[
|
||||||
importance_scores.get((module, name), getattr(module, name))
|
importance_scores.get((module, name), getattr(module, name))
|
||||||
for (module, name) in parameters
|
for (module, name) in parameters
|
||||||
|
@ -332,6 +332,7 @@ def spectral_norm(
|
|||||||
else:
|
else:
|
||||||
dim = 0
|
dim = 0
|
||||||
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
|
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
@@ -41,10 +41,7 @@ if not python_pytree._cxx_pytree_dynamo_traceable:
)


-# pyrefly: ignore # import-error
import optree
-
-# pyrefly: ignore # import-error
from optree import PyTreeSpec as TreeSpec # direct import for type annotations


@@ -679,7 +679,7 @@ class FlopCounterMode:


import tabulate
-# pyrefly: ignore # bad-assignment
tabulate.PRESERVE_WHITESPACE = True
header = ["Module", "FLOP", "% Total"]
values = []
@@ -9,7 +9,7 @@ from typing import Any, Optional
import torch
import numpy as np

-# pyrefly: ignore # import-error
from google.protobuf import struct_pb2

from tensorboard.compat.proto.summary_pb2 import (
@@ -956,7 +956,7 @@ class SummaryWriter:
)
self._projector_config.embeddings.extend([embedding_info])

-# pyrefly: ignore # import-error
from google.protobuf import text_format

config_pbtxt = text_format.MessageToString(self._projector_config)
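For context on the suppression syntax used throughout this diff: pyrefly appears to honor a `# pyrefly: ignore` comment placed on the line directly above the flagged expression (or inline on the same line), optionally narrowed to specific error codes after a second `#`. A minimal, illustrative sketch, not taken from this commit, mirroring the `Buffer.__new__` pattern touched above:

import torch

t = torch.empty(0)
# The standalone comment below silences only the "missing-attribute" diagnostic
# that a checker would report for the ad-hoc attribute set on the next line;
# the attribute name "persistent" here is purely illustrative.
# pyrefly: ignore # missing-attribute
t.persistent = True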