mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new algorithm for handling dispatch is that we first land in `__torch_function__` and we see if a special FCD implementation needs to be dispatched to first, and if there is nothing we fall back to the standard level strategy. Because there is no longer a C binding equivalent of the classes, we've condensed _C.Dim and Dim together, and similarly for Tensor. This resulted in some bugs as the Python API is sometimes different from the C API. I've attempted to disambiguate these but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful as Dim must abide by Tensor equality semantics, but uses pointer equality in C (DimEntry doesn't have this problem). Another subtle difference between C/Python is that we no longer get implicit conversions from Dim to DimEntry; this also caused some bugs. Much of the mechanical porting work was done by claude code. I have a separate PR that deletes functorch._C, but it was useful having dim.cpp to point claude at, so I haven't done it in this PR. From a reviewing perspective, I need to re-review that I didn't forget to port anything; one noticeably missing "small" thing is patched_dim_method. I am still in the process of carefully doing a side-by-side review of the ports; "simplifications" from claude code were also a major source of bugs. There are two major feature gaps in the implementation: - DelayedTensor and dot handling are not implemented yet. This should be reasonably easy, just need to do it. However, for the purposes of sharded propagation it is actually better not to reconstruct matmuls. - Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets this as advanced indexing and sends the list to torch.tensor to turn into a tensor, instead of being eligible for `__torch_function__`.
I think I might need to hard code a special case for this or something? Signed-off-by: Edward Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236 Approved by: https://github.com/zdevito, https://github.com/albanD
88 lines
3.2 KiB
Python
88 lines
3.2 KiB
Python
# mypy: allow-untyped-defs
|
|
from enum import Enum
|
|
|
|
from torch import Tensor
|
|
|
|
# Defined in torch/csrc/functorch/init.cpp
|
|
|
|
def _set_dynamic_layer_keys_included(included: bool) -> None: ...
|
|
def get_unwrapped(tensor: Tensor) -> Tensor: ...
|
|
def is_batchedtensor(tensor: Tensor) -> bool: ...
|
|
def is_functionaltensor(tensor: Tensor) -> bool: ...
|
|
def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
|
|
def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
|
|
def is_legacy_batchedtensor(tensor: Tensor) -> bool: ...
|
|
def maybe_get_bdim(tensor: Tensor) -> int: ...
|
|
def maybe_get_level(tensor: Tensor) -> int: ...
|
|
def maybe_current_level() -> int | None: ...
|
|
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
|
|
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
|
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
|
def _unwrap_batched(tensor: Tensor, level: int) -> tuple[Tensor, int | None]: ...
|
|
def current_level() -> int: ...
|
|
def count_jvp_interpreters() -> int: ...
|
|
def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
|
|
def _maybe_unsafe_set_level(tensor: Tensor, level: int) -> None: ...
|
|
def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
|
|
def get_single_level_autograd_function_allowed() -> bool: ...
|
|
def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
|
|
def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
|
|
def _vmap_increment_nesting(batch_size: int, randomness: str) -> int: ...
|
|
def _vmap_decrement_nesting() -> int: ...
|
|
def _grad_increment_nesting() -> int: ...
|
|
def _grad_decrement_nesting() -> int: ...
|
|
def _jvp_increment_nesting() -> int: ...
|
|
def _jvp_decrement_nesting() -> int: ...
|
|
|
|
# Defined in aten/src/ATen/functorch/Interpreter.h
|
|
class TransformType(Enum):
    """Kind of functorch transform an interpreter implements.

    Mirrors the C++ ``TransformType`` enum declared in
    aten/src/ATen/functorch/Interpreter.h; member values are opaque here.
    """

    Torch = ...
    Vmap = ...
    Grad = ...
    Jvp = ...
    Functionalize = ...
|
|
|
|
class RandomnessType(Enum):
    """Randomness policy used by a vmap interpreter.

    Matches the C++ enum in aten/src/ATen/functorch/Interpreter.h and the
    ``randomness=`` string accepted by ``_vmap_increment_nesting``.
    """

    Error = ...
    Same = ...
    Different = ...
|
|
|
|
class CInterpreter:
    """Stub for the C++ ``Interpreter`` object exposed to Python."""

    def key(self) -> "TransformType": ...
    def level(self) -> int: ...
    # serialize/deserialize round-trip an interpreter through opaque bytes.
    def serialize(self) -> bytes: ...
    # NOTE(review): the parameter was previously the bare name ``bytes``
    # (shadowing the builtin, no annotation) and the return annotation was an
    # unquoted self-reference, which is a NameError if this stub is executed.
    # These C bindings take the argument positionally, so the rename is safe —
    # confirm against torch/csrc/functorch/init.cpp.
    @staticmethod
    def deserialize(serialized: bytes) -> "CInterpreter": ...
|
|
|
|
class CGradInterpreterPtr:
    """Typed view over a ``CInterpreter`` for the grad transform."""

    def __init__(self, interpreter: "CInterpreter") -> None: ...
    # NOTE(review): the parameter was previously named ``Tensor`` (shadowing
    # the imported type) with no annotation; renamed since the C binding is
    # called positionally — confirm against torch/csrc/functorch/init.cpp.
    def lift(self, tensor: Tensor) -> Tensor: ...
    def prevGradMode(self) -> bool: ...
|
|
|
|
class CJvpInterpreterPtr:
    """Typed view over a ``CInterpreter`` for the jvp transform."""

    def __init__(self, interpreter: "CInterpreter") -> None: ...
    # NOTE(review): the parameter was previously named ``Tensor`` (shadowing
    # the imported type) with no annotation; renamed since the C binding is
    # called positionally — confirm against torch/csrc/functorch/init.cpp.
    def lift(self, tensor: Tensor) -> Tensor: ...
    def prevFwdGradMode(self) -> bool: ...
|
|
|
|
class CFunctionalizeInterpreterPtr:
    """Typed view over a ``CInterpreter`` for the functionalize transform."""

    def __init__(self, interpreter: "CInterpreter") -> None: ...
    def key(self) -> "TransformType": ...
    def level(self) -> int: ...
    def functionalizeAddBackViews(self) -> bool: ...
|
|
|
|
class CVmapInterpreterPtr:
    """Typed view over a ``CInterpreter`` for the vmap transform."""

    def __init__(self, interpreter: "CInterpreter") -> None: ...
    def key(self) -> "TransformType": ...
    def level(self) -> int: ...
    def batchSize(self) -> int: ...
    def randomness(self) -> "RandomnessType": ...
|
|
|
|
class DynamicLayer: ...


# ---------------------------------------------------------------------------
# Dynamic-layer (interpreter) stack manipulation (C bindings).
# ---------------------------------------------------------------------------

def get_dynamic_layer_stack_depth() -> int: ...
def get_interpreter_stack() -> list["CInterpreter"]: ...
def peek_interpreter_stack() -> "CInterpreter": ...
def pop_dynamic_layer_stack() -> "DynamicLayer": ...

# NOTE(review): the parameter was previously the bare name ``int`` (shadowing
# the builtin, no annotation); renamed since the C binding takes it
# positionally — confirm against torch/csrc/functorch/init.cpp.
def pop_dynamic_layer_stack_and_undo_to_depth(depth: int) -> None: ...
def push_dynamic_layer_stack(dl: "DynamicLayer") -> int: ...
|