[Dynamo] Automatically in-graph traceable tensor subclass ctors (#135151)
Fixes https://github.com/pytorch/pytorch/issues/114389

Previously, dynamo would attempt to trace through the `__init__` of traceable tensor subclasses. Since their constructors are AOT dispatcher traceable by definition, dynamo should automatically put these constructor calls in the graph, as it does for any other tensor. Not doing so is also difficult, because dynamo would need to apply mutations after the tensor subclass has been created in the graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135151
Approved by: https://github.com/bdhirsh
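As a rough illustration (a sketch, not taken from the PR's test suite), this is the kind of pattern the change targets: constructing a traceable wrapper subclass inside a compiled function. `TwoTensor` here is the traceable test subclass from `torch.testing._internal.two_tensor`; the exact behavior under `torch.compile` shown in the comments is an assumption.

```python
# Hedged sketch: TwoTensor is a traceable wrapper subclass from PyTorch's test
# utilities (it defines __tensor_flatten__/__tensor_unflatten__). With this change,
# dynamo is expected to place the TwoTensor(...) constructor call in the graph
# instead of tracing through TwoTensor.__init__.
import torch
from torch.testing._internal.two_tensor import TwoTensor


@torch.compile(backend="eager")
def f(a, b):
    sub = TwoTensor(a, b)  # subclass ctor is now in-graphed automatically
    return sub * 2


out = f(torch.randn(3), torch.randn(3))
```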
Committed by: PyTorch MergeBot
Parent: 67c7924ea1
Commit: 041960a1ce
```diff
@@ -3,7 +3,7 @@ import contextlib
 
 import warnings
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque
+from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type
 from typing_extensions import TypeGuard
 from collections import deque
 
@@ -391,6 +391,11 @@ def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
         and hasattr(t, "__tensor_unflatten__")
     )
 
+
+def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]:
+    """Same as above, but takes a type argument instead of an instance."""
+    return (issubclass(t, torch.Tensor) and t != torch.Tensor
+            and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))
 
 def transform_subclass(t, callback, outer_size=None, outer_stride=None):
     """
```
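A small usage sketch of the new helper added in this diff. The import path is an assumption (the diff does not show the file name, but `is_traceable_wrapper_subclass` lives in `torch.utils._python_dispatch`), and `TwoTensor` is again the test-only traceable subclass used purely for illustration.

```python
# Usage sketch (import path assumed to be torch.utils._python_dispatch).
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
from torch.testing._internal.two_tensor import TwoTensor  # test-only traceable subclass

print(is_traceable_wrapper_subclass_type(TwoTensor))     # True: defines __tensor_flatten__/__tensor_unflatten__
print(is_traceable_wrapper_subclass_type(torch.Tensor))  # False: plain torch.Tensor is excluded
```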