mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add sparse compressed fake tensor support (#120920)
As in the title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/120920 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
c06499981d
commit
ce2903080c
@ -8,7 +8,7 @@ from enum import Enum
|
||||
from torch.overrides import resolve_name
|
||||
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq
|
||||
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
|
||||
import torch.utils._python_dispatch
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
@ -490,7 +490,9 @@ def verbose_print(e):
|
||||
return self.s
|
||||
|
||||
def go(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
if is_sparse_any(t):
|
||||
return t
|
||||
elif isinstance(t, torch.Tensor):
|
||||
return Lit(f"{t} stride={t.stride()}")
|
||||
else:
|
||||
return t
|
||||
|
Reference in New Issue
Block a user