mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamic shapes] unbacked-safe slicing (#157944)
Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157944 Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
0254646654
commit
56218d85e2
@ -2616,7 +2616,9 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
if (
|
||||
func not in meta_table
|
||||
and not self.cpp_meta_supports_symint(func)
|
||||
and not (has_symbolic_sizes and func in self._view_fake_tensor_impl_ops)
|
||||
and not (
|
||||
has_symbolic_sizes and func in self._unbacked_special_fake_handling_ops
|
||||
)
|
||||
):
|
||||
from torch._decomp import decomposition_table
|
||||
|
||||
@ -2925,8 +2927,10 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
aten._sparse_coo_tensor_with_dims_and_tensors.default,
|
||||
)
|
||||
|
||||
_view_fake_tensor_impl_ops = ordered_set(
|
||||
aten.view.default, aten._unsafe_view.default
|
||||
_unbacked_special_fake_handling_ops = ordered_set(
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten.slice.Tensor,
|
||||
)
|
||||
|
||||
def cpp_meta_supports_symint(self, func: OpOverload) -> bool:
|
||||
|
Reference in New Issue
Block a user