diff --git a/test/export/test_export.py b/test/export/test_export.py index 6ad35b4554fe..a9365d846239 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -15967,6 +15967,28 @@ def forward(self, x, mask): ignore_empty_lines=True, ) + def test_unbacked_select_index(self): + class MyModel(torch.nn.Module): + def forward(self, x, y): + u0 = y.item() + return x.select(0, u0) + + example_inputs = ( + torch.randn((3, 3), dtype=torch.bfloat16), + torch.tensor([0]), + ) + + traced = export(MyModel(), example_inputs).run_decompositions({}) + self.assertExpectedInline( + traced.graph_module.code, + """\ +def forward(self, x, y): + item = torch.ops.aten.item.default(y); y = None + select = torch.ops.aten.select.int(x, 0, item); x = item = None + return (select,)""", + ignore_empty_lines=True, + ) + if __name__ == "__main__": run_tests() diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index af16a8f325fc..efa1875afdf4 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3529,6 +3529,88 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", ignore_empty_lines=True, ) + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_unbacked_select_index(self): + cnt = CompileCounterWithBackend("inductor") + + def func(x, y): + u0 = y.item() + return ( + torch.select(x, 0, u0), + torch.select(x, 1, u0), + torch.select(x, 2, u0), + ) + + compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) + x = torch.rand(3, 3, 3) + zero = torch.tensor([0]) + pos = torch.tensor([1]) + # code can handle both negative and positive indices. + neg = torch.tensor([-1]) + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + self.assertEqual(compiled_func(x, zero), func(x, zero)) + output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() + self.assertExpectedInline( + output, + """\ + _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None + select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense) + select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense) + select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None + return (select, select_1, select_2)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + self.assertEqual(compiled_func(x, pos), func(x, pos)) + self.assertEqual(compiled_func(x, neg), func(x, neg)) + self.assertEqual(cnt.frame_count, 1) + + def func2(x, y): + u0, u1 = y.tolist() + return torch.select(x, 0, u0 + u1) + + compiled_func2 = torch.compile(fullgraph=True, backend=cnt, dynamic=False)( + func2 + ) + zero = torch.tensor([0, 0]) + pos = torch.tensor([1, 1]) + neg = torch.tensor([-1, -1]) + + self.assertEqual(compiled_func2(x, pos), func2(x, pos)) + self.assertEqual(compiled_func2(x, neg), func2(x, neg)) + self.assertEqual(compiled_func2(x, zero), func2(x, zero)) + self.assertEqual(cnt.frame_count, 2) + + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_unbacked_select_index_with_check(self): + def func3(x, y): + u0 = y.item() + # Test that taking the non-unbacked path works fine also. 
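+            # torch._check teaches the symbolic reasoning that u0 is non-negative,
+            # so the select lowering can take the static (non-unbacked-offset) path.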
+            torch._check(u0 >= 0)
+            return (torch.select(x, 1, u0),)
+
+        compiled_func3 = torch.compile(
+            fullgraph=True, backend="inductor", dynamic=True
+        )(func3)
+        x = torch.rand(3, 3, 3)
+        zero = torch.tensor([0])
+        pos = torch.tensor([1])
+
+        self.assertEqual(compiled_func3(x, pos), func3(x, pos))
+        self.assertEqual(compiled_func3(x, zero), func3(x, zero))
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_unbacked_select_index_cpp_wrapper(self):
+        self.test_unbacked_select_index()
+
 
 instantiate_parametrized_tests(TestUnbacked)
diff --git a/torch/_export/passes/_node_metadata_hook.py b/torch/_export/passes/_node_metadata_hook.py
index 41005e500973..b1195cf42128 100644
--- a/torch/_export/passes/_node_metadata_hook.py
+++ b/torch/_export/passes/_node_metadata_hook.py
@@ -54,6 +54,7 @@ def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None)
                 )
             },
         )
+
         node.meta["torch_fn"] = (
             f"{node.target.__name__}_0",
             f"{node.target.__class__.__name__}.{node.target.__name__}",
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py
index fa880d35366c..7dd5fdc288ac 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -1452,6 +1452,20 @@ class CppWrapperCpu(PythonWrapperCodegen):
         # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
         self.unbacked_symbol_decls.add(str(node.sym))
 
+    def codegen_dynamic_select_index(self, node):
+        index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
+
+        index_compute_str = (
+            f"{index_cpp_str} < 0 ? {index_cpp_str} + "
+            f"{self.val_to_arg_str_for_prim_type(node.size, int)} : {index_cpp_str}"
+        )
+        self.writeline(
+            f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
+            f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});"
+        )
+        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
+        self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
+
     def make_buffer_free(self, buffer):
         return (
             ""
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 683282fa9c5a..72c2ffb555f9 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -1802,6 +1802,14 @@ class PythonWrapperCodegen(CodeGen):
         arg_name = node.input_name(0)
         self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
 
+    def codegen_dynamic_select_index(self, node):
+        index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
+        self.writeline(
+            f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
+        )
+        # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
+        self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
+
     def codegen_dynamic_scalar(self, node):
         (data,) = (t.codegen_reference() for t in node.inputs)
         if len(node.keypath) == 0:
diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py
index 8a374f5bab35..835ea182f8e8 100644
--- a/torch/_inductor/dependencies.py
+++ b/torch/_inductor/dependencies.py
@@ -11,6 +11,7 @@ from unittest.mock import patch
 import sympy
 
 import torch
+from torch._inductor.utils import get_free_symbols
 from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols
 from torch.utils._ordered_set import OrderedSet
 
 
@@ -38,6 +39,12 @@ class Dep(abc.ABC):
     name: str
     index: sympy.Expr
 
+    @abc.abstractmethod
+    def get_free_symbol_uses(
+        self, unbacked_only: bool = False
+    ) -> OrderedSet[sympy.Symbol]:
+        pass
+
     @abc.abstractmethod
     def rename(self, renames: dict[str, str]) -> Self:
         pass
@@ -70,6 +77,15 @@ class MemoryDep(Dep):
     size: tuple[sympy.Expr, ...]
     mode: Optional[str] = None
 
+    def get_free_symbol_uses(
+        self, unbacked_only: bool = False
+    ) -> OrderedSet[sympy.Symbol]:
+        return (
+            get_free_symbols(self.index, unbacked_only)
+            | get_free_symbols(self.size, unbacked_only)
+            | get_free_symbols(self.var_names, unbacked_only)
+        )
+
     def __repr__(self) -> str:
         maybe_mode = ""
         if self.mode is not None:
@@ -307,6 +323,11 @@ class StarDep(Dep):
             return StarDep(renames[self.name], self.mode)
         return self
 
+    def get_free_symbol_uses(
+        self, unbacked_only: bool = False
+    ) -> OrderedSet[sympy.Symbol]:
+        return OrderedSet()
+
     def numbytes_hint(self) -> int:
         try:
             return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
@@ -349,6 +370,11 @@ class WeakDep(Dep):
     # This flag is used to identify those additional deps.
     is_fake: bool = False
 
+    def get_free_symbol_uses(
+        self, unbacked_only: bool = False
+    ) -> OrderedSet[sympy.Symbol]:
+        return OrderedSet()
+
     @property
     def index(self) -> sympy.Expr:
         raise NotImplementedError("WeakDep does not have an index")
@@ -446,6 +472,15 @@ class ReadWrites:
                 names.add(dep.name)
         return names
 
+    def get_free_symbol_uses(
+        self, unbacked_only: bool = False
+    ) -> OrderedSet[sympy.Symbol]:
+        result: OrderedSet[sympy.Symbol] = OrderedSet()
+
+        for dep in self.reads_and_writes():
+            result |= dep.get_free_symbol_uses(unbacked_only)
+        return result
+
 
 class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
     def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index ac299d5b0c2d..660b01b69233 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -341,6 +341,7 @@ class GraphLowering(torch.fx.Interpreter):
             shape_env.deferred_runtime_asserts.copy()
         )
         self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
+
         self.sizevars = SizeVarAllocator(shape_env)
         self.graph_input_names: list[str] = []
         self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
@@ -1821,7 +1822,7 @@ class GraphLowering(torch.fx.Interpreter):
 
         shape_env = V.graph.sizevars.shape_env
 
-        # An input can an unbacked symint i.e.: when mark_unabcked is used.
+        # An input can be an unbacked symint, i.e. when mark_unbacked is used.
         # in that case add it to new_unbacked_defs.
         if (
             n.op == "placeholder"
@@ -1888,6 +1889,7 @@ class GraphLowering(torch.fx.Interpreter):
             V.fake_mode.shape_env.unbacked_renamings.get(s, s)
             for s in unbacked_bindings.keys()
         )
+
         assert new_unbacked_defs >= renamed_unbacked_bindings, (
             f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
             f"fx node is: {n.format_node()}\n"
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index a21b9c50938e..10c52eed2ecd 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -49,6 +49,7 @@ from torch._dynamo.utils import identity
 from torch._export.serde.serialize import GraphModuleSerializer
 from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
 from torch._inductor import metrics
+from torch._inductor.utils import get_free_symbols
 from torch._prims_common import (
     compute_required_storage_length,
     is_boolean_dtype,
@@ -62,7 +63,6 @@ from torch.fx.experimental.symbolic_shapes import (
     compute_unbacked_bindings,
     free_symbols,
     free_unbacked_symbols,
-    IterateExprs,
     rebind_unbacked,
     resolve_unbacked_bindings,
     ShapeEnv,
@@ -304,13 +304,6 @@ def fuse_reindexing(
     return reindex
 
 
-def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
-    if unbacked_only:
-        return free_unbacked_symbols(x)
-    else:
-        return free_symbols(x)
-
-
 NHWC_STRIDE_ORDER = [3, 0, 2, 1]
 NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
 
@@ -4377,6 +4370,13 @@ class ComputedBuffer(OperationBuffer):
         return self.data.get_read_names()
 
     def get_read_writes(self) -> dependencies.ReadWrites:
+        if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)):
+            return dependencies.ReadWrites(
+                reads=OrderedSet(),
+                writes=OrderedSet(),
+                index_exprs=OrderedSet(),
+            )
+
         with patch.object(FlexibleLayout, "allow_indexing", True):
             if self.data.get_reduction_type():
                 return extract_read_writes(
@@ -4415,6 +4415,7 @@ class ComputedBuffer(OperationBuffer):
             | get_free_symbols(self.get_stride(), unbacked_only)
             | get_free_symbols(self.get_offset(), unbacked_only)
             | self.data.get_free_symbol_uses(unbacked_only)
+            | self.get_read_writes().get_free_symbol_uses(unbacked_only)
         )
 
     def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
@@ -7023,6 +7024,50 @@ class DeviceCopy(ExternKernelOut):
             wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1])
 
 
+class DynamicSelectStorageOffset(ExternKernel):
+    """
+    When the index of a select operation is unbacked, the effective index cannot be
+    resolved at compile time: a negative index means index + size, while a non-negative
+    index is used as-is. To handle this, we allocate an unbacked SymInt to represent
+    the storage offset, decompose the select into a call to as_strided, and compute
+    the storage offset at runtime with this node.
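+
+    For example, selecting along dim 0 of a contiguous (3, 3) tensor (illustrative
+    values: base_offset = 0, base_dim_stride = 3, size = 3) with an unbacked index
+    u0 resolves the offset at runtime to 3 * (u0 + 3) = 6 when u0 == -1, and to
+    3 * u0 = 3 when u0 == 1.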
+ """ + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + def should_allocate(self) -> bool: + return False + + def __init__( + self, + unbacked_offset_symbol: sympy.Symbol, + index: sympy.Symbol, + base_offset: Union[sympy.Symbol, int], + base_dim_stride: Union[sympy.Symbol, int], + size: Union[sympy.Symbol, int], + ) -> None: + super().__init__(None, NoneLayout(device=torch.device("cpu")), []) + # This node codegen the following: + # unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size) + self.unbacked_offset_symbol = unbacked_offset_symbol + self.index = index + self.base_offset = base_offset + self.base_dim_stride = base_dim_stride + self.size = size + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet([self.unbacked_offset_symbol]) + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return get_free_symbols(self.index, unbacked_only) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.codegen_dynamic_select_index(self) + + class DynamicScalar(ExternKernel): """ The result of a call to aten._local_scalar_dense. diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 503795bc513c..5da93369672d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -40,7 +40,11 @@ from torch._prims_common import ( Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.fx.experimental.symbolic_shapes import ( + free_unbacked_symbols, + has_free_unbacked_symbols, + resolve_unbacked_bindings, +) from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing @@ -978,10 +982,7 @@ def squeeze(x, dim=None): new_shape = [] for d, s in enumerate(x.get_size()): - if not ( - d in dims - and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True) - ): + if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))): new_shape.append(s) # squeeze does nothing if the size isn't 1 @@ -1747,8 +1748,60 @@ def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): @register_lowering(aten.select, type_promotion_kind=None) def select(x, dim, idx): - idx = View.handle_negative_index(idx, x.get_size()[dim]) - return squeeze(slice_(x, dim, idx, idx + 1), dim) + idx = sympy.expand(idx) + size = sympy.expand(x.get_size()[dim]) + actual_index = None + + if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)): + actual_index = idx + size + elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)): + actual_index = idx + + if actual_index is not None: + if has_free_unbacked_symbols(idx): + # Inductor could generate incorrect views for tensors with unbacked symbols here; + # Squeeze operations are translated to views, resulting in incorrect strides. + # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this, + # we use as_strided instead. + # Removing this branch will cause test_unbacked_select_index_with_check to fail. 
+            new_size = x.get_size()
+            new_stride = x.get_stride()
+            new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index
+
+            del new_size[dim]
+            del new_stride[dim]
+            return as_strided(x, new_size, new_stride, new_storage_offset)
+        else:
+            slice_result = slice_(x, dim, actual_index, actual_index + 1)
+            return squeeze(slice_result, dim)
+
+    # Unbacked semantics:
+    # When the index idx is unbacked (e.g., u0), its sign is unknown at compile
+    # time, so we emit a DynamicSelectStorageOffset node that computes the storage
+    # offset at runtime.
+
+    unbacked_bindings = resolve_unbacked_bindings(
+        V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
+    )
+    assert unbacked_bindings is not None
+    assert len(unbacked_bindings) == 1, unbacked_bindings
+    unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))
+
+    new_size = x.get_size()
+    new_stride = x.get_stride()
+    new_storage_offset = unbacked_offset_sym
+    buffer = ir.DynamicSelectStorageOffset(
+        unbacked_offset_sym,
+        idx,
+        x.get_layout().offset,
+        new_stride[dim],
+        x.get_size()[dim],
+    )
+    buffer.name = V.graph.register_buffer(buffer)
+    V.graph.register_operation(buffer)
+
+    del new_size[dim]
+    del new_stride[dim]
+    return as_strided(x, new_size, new_stride, new_storage_offset)
 
 
 @register_lowering(aten.split, type_promotion_kind=None)
@@ -3074,8 +3127,6 @@ def long_tensor(data):
 
 @register_lowering(aten._local_scalar_dense)
 def _local_scalar_dense(data):
-    from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings
-
     # This is interesting! Most lowerings return tensors, so you can just
     # return the buffer you allocated and it will get used (or not used, if
     # it's dead.)  But _local_scalar_dense (aka item) returns an int,
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index f17876e152e3..5d4ce3e4f58f 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -2130,9 +2130,11 @@ class Scheduler:
         self.logged_slow_fusion = OrderedSet[tuple[str, str]]()
         if config._pre_fusion_custom_pass is not None:
             self.nodes = config._pre_fusion_custom_pass(self.nodes)
+
         self.nodes = self.fuse_nodes(self.nodes)
         if config._post_fusion_custom_pass is not None:
             self.nodes = config._post_fusion_custom_pass(self.nodes)
+
         self.merge_loops()
         self.finalize_multi_template_buffers()
         if config.combo_kernels:
@@ -2383,7 +2385,6 @@ class Scheduler:
 
         for node in self.nodes:
             log.debug("scheduling %s", node.node)
-
             # unbacked symbols don't follow ordinary buffer dependencies, so
             # we track their def/uses separately
             assert node.node is not None
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 1163144498bd..c0e29bf27a05 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -69,13 +69,20 @@ OPTIMUS_EXCLUDE_POST_GRAD = [
     "inductor_autotune_lookup_table",
 ]
 
+from torch.fx.experimental.symbolic_shapes import (
+    free_symbols,
+    free_unbacked_symbols,
+    IterateExprs,
+    ShapeEnv,
+)
+
+
 if TYPE_CHECKING:
     from collections.abc import Iterable, Sequence, ValuesView
 
     from torch import SymBool, SymFloat, SymInt
     from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
     from torch.fx import GraphModule
-    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from torch.fx.node import Node
 
     from .codegen.common import WorkspaceArg
@@ -3355,3 +3362,10 @@ def aoti_model_name_from_config() -> str:
     model_name = config.aot_inductor.model_name_for_generated_files
     model_name = "aoti_model" if model_name is None else model_name
     return model_name
+
+
+def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
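+    """Return the free symbols of x, or only its unbacked free symbols when unbacked_only is True."""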
+    if unbacked_only:
+        return free_unbacked_symbols(x)
+    else:
+        return free_symbols(x)
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index ae87e0e17fb3..2933a37c37fd 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5553,39 +5553,6 @@ def meta_zeros(
     )
 
 
-@register_meta(aten.select.int)
-def meta_select(self, dim, index):
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
-    ndim = self.dim()
-    torch._check_index(
-        ndim != 0,
-        lambda: "select() cannot be applied to a 0-dim tensor.",
-    )
-
-    dim = dim if dim >= 0 else dim + ndim
-    size = self.size(dim)
-
-    torch._check_index(
-        not (
-            guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size)
-        ),
-        lambda: f"select(): index {index} out of range for tensor of size "
-        f"{self.size()} at dimension {dim}",
-    )
-
-    index = index if index >= 0 else index + size
-
-    new_size = list(self.size())
-    new_stride = list(self.stride())
-
-    new_storage_offset = self.storage_offset() + index * new_stride[dim]
-    del new_size[dim]
-    del new_stride[dim]
-
-    return self.as_strided(new_size, new_stride, new_storage_offset)
-
-
 @register_meta(aten.select_scatter.default)
 def meta_select_scatter(self, src, dim, index):
     return utils.clone_preserve_strides(self)
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index e802d9a4389d..e2e24cb59bc2 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -359,6 +359,48 @@ def unique2(
     return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
 
 
+@register_op_impl(aten.select.int)
+def meta_select(fake_mode, func, self, dim, index):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
+
+    if self.is_sparse:
+        return NotImplemented
+
+    ndim = self.dim()
+    torch._check_index(
+        ndim != 0,
+        lambda: "select() cannot be applied to a 0-dim tensor.",
+    )
+
+    dim = dim if dim >= 0 else dim + ndim
+    size = self.size(dim)
+
+    new_size = list(self.size())
+    new_stride = list(self.stride())
+
+    new_storage_offset = None
+    if guard_or_false(index >= 0):
+        new_storage_offset = self.storage_offset() + index * new_stride[dim]
+    elif guard_or_false(index < 0):
+        new_storage_offset = self.storage_offset() + (index + size) * new_stride[dim]
+
+    if new_storage_offset is None:
+        if fake_mode.shape_env is None or (
+            not fake_mode.shape_env.allow_scalar_outputs
+            and not fake_mode.allow_scalar_outputs
+        ):
+            raise DataDependentOutputException(func)
+
+        # The index is data-dependent: we cannot tell whether the effective index
+        # is index or index + size, so we allocate a new data-dependent symbol for
+        # the storage offset.
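+        # The inductor lowering of select later binds this symbol via a
+        # DynamicSelectStorageOffset node, which computes the concrete offset at runtime.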
+        new_storage_offset = fake_mode.shape_env.create_unbacked_symint()
+
+    del new_size[dim]
+    del new_stride[dim]
+    assert new_storage_offset is not None
+    return self.as_strided(new_size, new_stride, new_storage_offset)
+
+
 @register_op_impl(aten.unique_dim.default)
 def unique_dim(
     fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index c4eb239437bc..375019f7dc83 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1282,6 +1282,7 @@ def compute_unbacked_bindings(
             return None
 
     fs = shape_env.pending_fresh_unbacked_symbols
+
     pending = set(fs)
     if not pending:
         return None
diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py
index 38c64c527aff..bb71a25971da 100644
--- a/torch/fx/passes/runtime_assert.py
+++ b/torch/fx/passes/runtime_assert.py
@@ -461,6 +461,7 @@ def insert_deferred_runtime_asserts(
                         ),
                         keypath[2:],
                     )
+
                 return go(
                     graph.call_method(
                         keypath[0].name, (node, keypath[1].idx)
@@ -468,6 +469,15 @@ def insert_deferred_runtime_asserts(
                     keypath[2:],
                 )
             elif isinstance(keypath[0], CallMethodKey):
+                if keypath[0].name == "storage_offset":
+                    return go(
+                        graph.call_function(
+                            torch.ops.aten.sym_storage_offset.default,
+                            (node,),
+                        ),
+                        keypath[1:],
+                    )
+
                 return go(
                     graph.call_method(keypath[0].name, (node,)), keypath[1:]
                 )