mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix get_free_symbol_uses for several nodes (#160314)
get_free_symbol_uses is used to know what unbacked symbols are used by a given node. not having correct get_free_symbol_uses defined properly leads to : - eliminating of some nodes due to not detection of any users. (See the added unit test) - Incorrect topological sort. Fix get_free_symbol_uses , NopKernel , ConcarKernel, InputsKerenl, external kernel. for ComputedBuffer with NonOwningLayout its interesting case. when layout is NonOwningLayout we need to access the actual view op base layout and use detect symbols in it. Because when we codegen the ComputedBuffer we uses those symbols. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160314 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
ecde76c764
commit
96bd33b2de
@ -74,15 +74,15 @@ aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0
|
||||
|
||||
|
||||
|
||||
mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.1
|
||||
mm_loop_inductor_gpu,compile_time_instruction_count,4495000000,0.1
|
||||
|
||||
|
||||
|
||||
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1
|
||||
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8462000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_NestedModule_eager,compile_time_instruction_count,9199000000,0.1
|
||||
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
|
||||
|
||||
|
||||
|
||||
|
|
@ -3616,6 +3616,17 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
||||
def test_unbacked_select_index_cpp_wrapper(self):
|
||||
self.test_unbacked_select_index()
|
||||
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_select2(self):
|
||||
def f(idx, x):
|
||||
x = x.select(0, idx.item())
|
||||
return x @ x
|
||||
|
||||
x = torch.randn(3, 3, 3)
|
||||
idx = torch.tensor(1, dtype=torch.int64)
|
||||
out = torch.compile(f)(idx, x)
|
||||
self.assertEqual(out, f(idx, x))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestUnbacked)
|
||||
|
||||
|
@ -3576,6 +3576,11 @@ class OutputSpec:
|
||||
def storage_size(self) -> int:
|
||||
raise NotImplementedError(type(self).__name__)
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
raise NotImplementedError(type(self).__name__)
|
||||
|
||||
|
||||
@ir_dataclass
|
||||
class Layout(OutputSpec):
|
||||
@ -3807,6 +3812,15 @@ class Layout(OutputSpec):
|
||||
def storage_size(self) -> Expr:
|
||||
return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return (
|
||||
get_free_symbols(self.size, unbacked_only)
|
||||
| get_free_symbols(self.stride, unbacked_only)
|
||||
| get_free_symbols(self.offset, unbacked_only)
|
||||
)
|
||||
|
||||
|
||||
class FixedLayout(Layout):
|
||||
"""A Tensor layout we cannot change"""
|
||||
@ -3999,6 +4013,16 @@ class NonOwningLayout(Layout):
|
||||
|
||||
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
assert isinstance(self.view, ReinterpretView)
|
||||
box = self.view.data
|
||||
assert isinstance(box, StorageBox), type(box)
|
||||
input_buffer = box.data
|
||||
assert isinstance(input_buffer, Buffer), type(box)
|
||||
return input_buffer.layout.get_free_symbol_uses(unbacked_only)
|
||||
|
||||
|
||||
class CommBufferType(Enum):
|
||||
SYMM_MEM = "symm_mem"
|
||||
@ -4382,6 +4406,10 @@ class ShapeAsConstantBuffer(IRNode):
|
||||
|
||||
@ir_dataclass(frozen=False)
|
||||
class ComputedBuffer(OperationBuffer):
|
||||
"""
|
||||
Represents a buffer that is computed during kernel execution rather than being an input.
|
||||
"""
|
||||
|
||||
data: Loops
|
||||
|
||||
def get_computed_buffer_name(self) -> Optional[str]:
|
||||
@ -4437,21 +4465,20 @@ class ComputedBuffer(OperationBuffer):
|
||||
# those symbols that establishes a dependency). However, we haven't
|
||||
# started codegen yet so we can't directly reuse that logic.
|
||||
#
|
||||
# For now, I'm just yoloing with the size of the buffer. Not sure if
|
||||
# it is enough.
|
||||
#
|
||||
# One thing you might wonder is if this is enough for a ComputedBuffer
|
||||
# denoting a reduction over i0. Empirically, it is enough, but for an
|
||||
# unusual reason: we only need accurate dependencies for item() call,
|
||||
# but it's impossible to end up with a reduction over i0 from an
|
||||
# item() call without a regular non-reduction buffer first.
|
||||
return (
|
||||
get_free_symbols(self.get_size(), unbacked_only)
|
||||
| get_free_symbols(self.get_stride(), unbacked_only)
|
||||
| get_free_symbols(self.get_offset(), unbacked_only)
|
||||
| self.data.get_free_symbol_uses(unbacked_only)
|
||||
| self.get_read_writes().get_free_symbol_uses(unbacked_only)
|
||||
)
|
||||
result = self.layout.get_free_symbol_uses(
|
||||
unbacked_only
|
||||
) | self.data.get_free_symbol_uses(unbacked_only)
|
||||
|
||||
if self.has_store_function() and isinstance(
|
||||
self.get_store_function(), LoopBody
|
||||
):
|
||||
result |= self.get_read_writes().get_free_symbol_uses(unbacked_only)
|
||||
return result
|
||||
|
||||
def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
|
||||
if (
|
||||
@ -4463,6 +4490,9 @@ class ComputedBuffer(OperationBuffer):
|
||||
return self.data.make_loader()
|
||||
return super().make_loader()
|
||||
|
||||
def has_store_function(self) -> bool:
|
||||
return isinstance(self.data, (Reduction, Scan, Sort, Pointwise))
|
||||
|
||||
def get_store_function(self) -> Callable[..., None]:
|
||||
indexer = self.get_layout().as_fixed().make_indexer()
|
||||
if isinstance(self.data, (Reduction, Scan, Sort)):
|
||||
@ -5170,6 +5200,18 @@ class InputsKernel(OperationBuffer):
|
||||
def num_reads(self) -> int:
|
||||
return 1
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
r = OrderedSet[sympy.Symbol]()
|
||||
for inp in self.inputs:
|
||||
if isinstance(inp, IRNode):
|
||||
r |= inp.get_free_symbol_uses(unbacked_only)
|
||||
else:
|
||||
for inner_inp in inp:
|
||||
r |= inner_inp.get_free_symbol_uses(unbacked_only)
|
||||
return r
|
||||
|
||||
|
||||
class NopKernel(InputsKernel):
|
||||
def is_no_op(self) -> bool:
|
||||
@ -5332,6 +5374,11 @@ class ConcatKernel(NopKernel):
|
||||
and not isinstance(src.data, ExternKernelAlloc)
|
||||
)
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return NopKernel.get_free_symbol_uses(self, unbacked_only)
|
||||
|
||||
@classmethod
|
||||
def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode:
|
||||
# Attempt to turn this into a ReinterpretView rather than assert.
|
||||
@ -6232,7 +6279,7 @@ class ExternKernel(InputsKernel):
|
||||
maybe_get_symbols = (
|
||||
maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols
|
||||
)
|
||||
r = OrderedSet[sympy.Symbol]()
|
||||
r = InputsKernel.get_free_symbol_uses(self, unbacked_only)
|
||||
for arg in self.constant_args:
|
||||
r |= maybe_get_symbols(arg)
|
||||
for arg in self.kwargs.values():
|
||||
@ -8690,7 +8737,10 @@ class EffectfulKernel(FallbackKernel):
|
||||
|
||||
|
||||
class NonTensorObj(IRNode):
|
||||
pass
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
|
||||
@ir_dataclass
|
||||
|
@ -2396,11 +2396,11 @@ class Scheduler:
|
||||
for fs in s.free_symbols:
|
||||
unbacked_symbol_to_origin_node[fs] = None
|
||||
|
||||
has_non_input_unbacked_defs = False
|
||||
for node in self.nodes:
|
||||
log.debug("scheduling %s", node.node)
|
||||
assert node.node is not None
|
||||
# unbacked symbols don't follow ordinary buffer dependencies, so
|
||||
# we track their def/uses separately
|
||||
assert node.node is not None
|
||||
unbacked_symbol_defs = sorted(
|
||||
node.node.get_unbacked_symbol_defs(), key=lambda x: x.name
|
||||
)
|
||||
@ -2409,20 +2409,28 @@ class Scheduler:
|
||||
# Pick the first definer as canonical. There may be multiple
|
||||
# because if a MultiOutputLayout buffer propagates an unbacked
|
||||
# symint to multiple outputs, they will all claim to def it.
|
||||
has_non_input_unbacked_defs = True
|
||||
if s not in unbacked_symbol_to_origin_node:
|
||||
unbacked_symbol_to_origin_node[s] = node.get_name()
|
||||
|
||||
unbacked_symbol_uses = sorted(
|
||||
node.node.get_free_symbol_uses(unbacked_only=True), key=lambda x: x.name
|
||||
)
|
||||
# if a kernel takes unbacked symints, register dependencies
|
||||
for s in unbacked_symbol_uses:
|
||||
assert s in unbacked_symbol_to_origin_node, (
|
||||
f"{s} not in {unbacked_symbol_to_origin_node}"
|
||||
for node in self.nodes:
|
||||
log.debug("scheduling %s", node.node)
|
||||
|
||||
if has_non_input_unbacked_defs:
|
||||
assert node.node is not None
|
||||
|
||||
unbacked_symbol_uses = sorted(
|
||||
node.node.get_free_symbol_uses(unbacked_only=True),
|
||||
key=lambda x: x.name,
|
||||
)
|
||||
if (r := unbacked_symbol_to_origin_node[s]) is not None:
|
||||
for buf in self.name_to_node[r].get_outputs():
|
||||
node.add_fake_dep(StarDep(buf.get_name()))
|
||||
# if a kernel takes unbacked symints, register dependencies
|
||||
for s in unbacked_symbol_uses:
|
||||
assert s in unbacked_symbol_to_origin_node, (
|
||||
f"{s} not in {unbacked_symbol_to_origin_node}"
|
||||
)
|
||||
if (r := unbacked_symbol_to_origin_node[s]) is not None:
|
||||
for buf in self.name_to_node[r].get_outputs():
|
||||
node.add_fake_dep(StarDep(buf.get_name()))
|
||||
|
||||
if (
|
||||
len(node.read_writes.writes) == 1
|
||||
@ -2477,17 +2485,20 @@ class Scheduler:
|
||||
add_user(buf_name, OutputNode(StarDep(buf_name)))
|
||||
|
||||
# make sure unbacked symints aren't dead-code-eliminated
|
||||
for out in V.graph.graph_outputs:
|
||||
for s in out.get_free_symbol_uses(unbacked_only=True):
|
||||
assert s in unbacked_symbol_to_origin_node, (
|
||||
f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
|
||||
)
|
||||
if r := unbacked_symbol_to_origin_node[s]:
|
||||
for buf_name in self.name_to_node[r].get_buffer_names():
|
||||
log.debug(
|
||||
"scheduling output %s for unbacked symint %s", buf_name, s
|
||||
)
|
||||
add_user(buf_name, OutputNode(StarDep(buf_name)))
|
||||
if has_non_input_unbacked_defs:
|
||||
for out in V.graph.graph_outputs:
|
||||
for s in out.get_free_symbol_uses(unbacked_only=True):
|
||||
assert s in unbacked_symbol_to_origin_node, (
|
||||
f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
|
||||
)
|
||||
if r := unbacked_symbol_to_origin_node[s]:
|
||||
for buf_name in self.name_to_node[r].get_buffer_names():
|
||||
log.debug(
|
||||
"scheduling output %s for unbacked symint %s",
|
||||
buf_name,
|
||||
s,
|
||||
)
|
||||
add_user(buf_name, OutputNode(StarDep(buf_name)))
|
||||
|
||||
# make sure input mutation isn't dead-code-eliminated
|
||||
for name in self.mutation_renames:
|
||||
|
Reference in New Issue
Block a user