Add sparse compressed fake tensor support (#120920)

As in the title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120920
Approved by: https://github.com/ezyang
This commit is contained in:
Pearu Peterson
2024-03-04 11:32:05 +02:00
committed by PyTorch MergeBot
parent c06499981d
commit ce2903080c
11 changed files with 148 additions and 7 deletions

View File

@ -8,7 +8,7 @@ from enum import Enum
from torch.overrides import resolve_name
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
import torch.utils._python_dispatch
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import OpOverload, OpOverloadPacket
@ -490,7 +490,9 @@ def verbose_print(e):
return self.s
def go(t):
if isinstance(t, torch.Tensor):
if is_sparse_any(t):
return t
elif isinstance(t, torch.Tensor):
return Lit(f"{t} stride={t.stride()}")
else:
return t