[Dynamo] Automatically in-graph traceable tensor subclass ctors (#135151)

Fixes https://github.com/pytorch/pytorch/issues/114389

Previously, dynamo would attempt to trace through the `__init__` of traceable tensor subclasses, since their constructors are AOT dispatcher traceable by definition, dynamo should automatically put these in the graph like we do for any other tensors. Not doing this is difficult because dynamo would need to apply mutations post tensor subclass creation in the graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135151
Approved by: https://github.com/bdhirsh
This commit is contained in:
Michael Lazos
2024-09-05 16:42:55 -07:00
committed by PyTorch MergeBot
parent 67c7924ea1
commit 041960a1ce
4 changed files with 109 additions and 2 deletions

View File

@ -3,7 +3,7 @@ import contextlib
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type
from typing_extensions import TypeGuard
from collections import deque
@ -391,6 +391,11 @@ def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
and hasattr(t, "__tensor_unflatten__")
)
def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]:
"""Same as above, but takes a type argument instead of an instance."""
return (issubclass(t, torch.Tensor) and t != torch.Tensor
and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))
def transform_subclass(t, callback, outer_size=None, outer_stride=None):
"""