mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix get_free_symbol_uses for several nodes. (#160134)
get_free_symbol_uses is used to know what unbacked symbols are used by a given node. not having correct get_free_symbol_uses defined properly leads to : 1. eliminating of some nodes due to not detection of any users. (See the added unit test) 2. Incorrect topological sort. Fix get_free_symbol_uses , NopKernel , ConcarKernel, InputsKerenl, external kernel. for ComputedBuffer with NonOwningLayout its interesting case. when layout is NonOwningLayout we need to access the actual view op base layout and use detect symbols in it. Because when we codegen the ComputedBuffer we uses those symbols. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160134 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
29712314dd
commit
db78943a1c
@ -3616,6 +3616,17 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
||||
def test_unbacked_select_index_cpp_wrapper(self):
|
||||
self.test_unbacked_select_index()
|
||||
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_select2(self):
|
||||
def f(idx, x):
|
||||
x = x.select(0, idx.item())
|
||||
return x @ x
|
||||
|
||||
x = torch.randn(3, 3, 3)
|
||||
idx = torch.tensor(1, dtype=torch.int64)
|
||||
out = torch.compile(f)(idx, x)
|
||||
self.assertEqual(out, f(idx, x))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestUnbacked)
|
||||
|
||||
|
@ -4443,7 +4443,8 @@ class ComputedBuffer(OperationBuffer):
|
||||
# unusual reason: we only need accurate dependencies for item() call,
|
||||
# but it's impossible to end up with a reduction over i0 from an
|
||||
# item() call without a regular non-reduction buffer first.
|
||||
return (
|
||||
|
||||
result = (
|
||||
get_free_symbols(self.get_size(), unbacked_only)
|
||||
| get_free_symbols(self.get_stride(), unbacked_only)
|
||||
| get_free_symbols(self.get_offset(), unbacked_only)
|
||||
@ -4451,6 +4452,21 @@ class ComputedBuffer(OperationBuffer):
|
||||
| self.get_read_writes().get_free_symbol_uses(unbacked_only)
|
||||
)
|
||||
|
||||
if isinstance(self.layout, NonOwningLayout):
|
||||
assert isinstance(self.layout.view, ReinterpretView)
|
||||
box = self.layout.view.data
|
||||
assert isinstance(box, StorageBox), type(box)
|
||||
input_buffer = box.data
|
||||
assert isinstance(input_buffer, Buffer), type(box)
|
||||
result = (
|
||||
result
|
||||
| get_free_symbols(input_buffer.get_size(), unbacked_only)
|
||||
| get_free_symbols(input_buffer.get_stride(), unbacked_only)
|
||||
| get_free_symbols(input_buffer.get_offset(), unbacked_only)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
|
||||
if (
|
||||
not self.get_reduction_type()
|
||||
@ -5126,6 +5142,18 @@ class InputsKernel(OperationBuffer):
|
||||
def get_reads(self) -> OrderedSet[Dep]:
|
||||
return self.get_read_writes().reads
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
r = OrderedSet[sympy.Symbol]()
|
||||
for inp in self.inputs:
|
||||
if isinstance(inp, IRNode):
|
||||
r |= inp.get_free_symbol_uses(unbacked_only)
|
||||
else:
|
||||
for inner_inp in inp:
|
||||
r |= inner_inp.get_free_symbol_uses(unbacked_only)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def unwrap_storage_for_input(cls, x: IRNode) -> IRNode:
|
||||
if isinstance(x, TensorBox):
|
||||
@ -5172,6 +5200,11 @@ class NopKernel(InputsKernel):
|
||||
def get_reads(self) -> OrderedSet[Dep]:
|
||||
return OrderedSet()
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return InputsKernel.get_free_symbol_uses(self, unbacked_only)
|
||||
|
||||
|
||||
class ConcatKernel(NopKernel):
|
||||
"""
|
||||
@ -5326,6 +5359,11 @@ class ConcatKernel(NopKernel):
|
||||
and not isinstance(src.data, ExternKernelAlloc)
|
||||
)
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return NopKernel.get_free_symbol_uses(self, unbacked_only)
|
||||
|
||||
@classmethod
|
||||
def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode:
|
||||
# Attempt to turn this into a ReinterpretView rather than assert.
|
||||
@ -6221,12 +6259,10 @@ class ExternKernel(InputsKernel):
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
# NB: It's not necessary to check regular inputs as we automatically
|
||||
# have dependencies on them
|
||||
maybe_get_symbols = (
|
||||
maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols
|
||||
)
|
||||
r = OrderedSet[sympy.Symbol]()
|
||||
r = InputsKernel.get_free_symbol_uses(self, unbacked_only)
|
||||
for arg in self.constant_args:
|
||||
r |= maybe_get_symbols(arg)
|
||||
for arg in self.kwargs.values():
|
||||
|
Reference in New Issue
Block a user