mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
set unbacked bindings in reinplace pass for newly created nodes during generalize_scatter decomp (#164948)
Two fixes: 1. in rein_place pass, set unbacked bindings for newly created nodes. 2. In inductor, ComputeBuffer used to miss detecting some used symbols, fixed that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164948 Approved by: https://github.com/bobrenjc93 ghstack dependencies: #164341
This commit is contained in:
committed by
PyTorch MergeBot
parent
c6a8db0b9a
commit
017d2985f3
@ -4301,6 +4301,34 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
|||||||
accumulate(X0, torch.tensor([1])), compiled(X0, torch.tensor([1]))
|
accumulate(X0, torch.tensor([1])), compiled(X0, torch.tensor([1]))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||||
|
def test_unbacked_item_set_item3(self):
|
||||||
|
def func(x, y):
|
||||||
|
u0 = y.item()
|
||||||
|
x[u0] = 0
|
||||||
|
return x
|
||||||
|
|
||||||
|
compiled = torch.compile(func, fullgraph=True, disable=False)
|
||||||
|
b = torch.tensor([0])
|
||||||
|
a = torch.ones(9, dtype=torch.int32)
|
||||||
|
|
||||||
|
compiled(a, b)
|
||||||
|
self.assertEqual(compiled(a, b), func(a, b))
|
||||||
|
|
||||||
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||||
|
def test_select_scatter_unbacked_index(self):
|
||||||
|
def func(x, y):
|
||||||
|
u0 = y.item()
|
||||||
|
# Create a scalar tensor to scatter into the selected index
|
||||||
|
scalar_src = torch.tensor(42, dtype=x.dtype)
|
||||||
|
return x.select_scatter(scalar_src, 0, u0)
|
||||||
|
|
||||||
|
compiled = torch.compile(func, fullgraph=True, dynamic=True, backend="inductor")
|
||||||
|
b = torch.tensor([0])
|
||||||
|
a = torch.ones(9, dtype=torch.int32)
|
||||||
|
|
||||||
|
self.assertEqual(compiled(a, b), func(a, b))
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(TestUnbacked)
|
instantiate_parametrized_tests(TestUnbacked)
|
||||||
|
|
||||||
|
@ -24,7 +24,10 @@ from torch._inductor.lowering import (
|
|||||||
inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
|
inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
|
||||||
)
|
)
|
||||||
from torch._inductor.virtualized import V
|
from torch._inductor.virtualized import V
|
||||||
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
|
from torch.fx.experimental.symbolic_shapes import (
|
||||||
|
compute_unbacked_bindings,
|
||||||
|
GuardOnDataDependentSymNode,
|
||||||
|
)
|
||||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||||
from torch.fx.passes.reinplace import _is_view_op
|
from torch.fx.passes.reinplace import _is_view_op
|
||||||
from torch.utils import _pytree as pytree
|
from torch.utils import _pytree as pytree
|
||||||
@ -60,7 +63,9 @@ def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
|
|||||||
fake_result = fn(*fake_args, **fake_kwargs)
|
fake_result = fn(*fake_args, **fake_kwargs)
|
||||||
|
|
||||||
node = graph.call_function(fn, args, kwargs)
|
node = graph.call_function(fn, args, kwargs)
|
||||||
|
|
||||||
node.meta["val"] = fake_result
|
node.meta["val"] = fake_result
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
@ -171,6 +176,13 @@ def _decompose_scatter_mutating(
|
|||||||
tmp = inp
|
tmp = inp
|
||||||
for view in view_ops: # type: ignore[union-attr]
|
for view in view_ops: # type: ignore[union-attr]
|
||||||
tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr]
|
tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr]
|
||||||
|
# we need to set unbacked bindings that could have been created in the view ops.
|
||||||
|
if (V.fake_mode.shape_env) and (
|
||||||
|
symbol_to_path := compute_unbacked_bindings(
|
||||||
|
V.fake_mode.shape_env, tmp.meta["val"]
|
||||||
|
)
|
||||||
|
):
|
||||||
|
tmp.meta["unbacked_bindings"] = symbol_to_path
|
||||||
|
|
||||||
graph_call_function(graph, aten.copy_.default, tmp, src)
|
graph_call_function(graph, aten.copy_.default, tmp, src)
|
||||||
return inp # type: ignore[return-value]
|
return inp # type: ignore[return-value]
|
||||||
|
@ -4542,9 +4542,7 @@ class ComputedBuffer(OperationBuffer):
|
|||||||
unbacked_only
|
unbacked_only
|
||||||
) | self.data.get_free_symbol_uses(unbacked_only)
|
) | self.data.get_free_symbol_uses(unbacked_only)
|
||||||
|
|
||||||
if self.has_store_function() and isinstance(
|
if self.has_store_function():
|
||||||
self.get_store_function(), LoopBody
|
|
||||||
):
|
|
||||||
result |= self.get_read_writes().get_free_symbol_uses(unbacked_only)
|
result |= self.get_read_writes().get_free_symbol_uses(unbacked_only)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user