mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
WIP Support python slicing with data depedennt inptu tensors maybe
ghstack-source-id: 5363cf5565c2c024260b6ae504ba12ffca2f9984 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165074
This commit is contained in:
@ -528,30 +528,6 @@ Attempted to call function marked as skipped
|
|||||||
f(x)
|
f(x)
|
||||||
self.assertEqual(len(ws), 2)
|
self.assertEqual(len(ws), 2)
|
||||||
|
|
||||||
def test_slice_with_tensor(self):
|
|
||||||
def fn(x, y):
|
|
||||||
return x[:y]
|
|
||||||
|
|
||||||
self.assertExpectedInlineMunged(
|
|
||||||
Unsupported,
|
|
||||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
|
|
||||||
torch.randn(10),
|
|
||||||
torch.tensor([3]),
|
|
||||||
),
|
|
||||||
"""\
|
|
||||||
Dynamic slicing with Tensor arguments
|
|
||||||
Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
|
|
||||||
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
|
|
||||||
|
|
||||||
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
|
|
||||||
|
|
||||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
|
|
||||||
|
|
||||||
from user code:
|
|
||||||
File "test_error_messages.py", line N, in fn
|
|
||||||
return x[:y]""",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_observed_exception(self):
|
def test_observed_exception(self):
|
||||||
def fn():
|
def fn():
|
||||||
raise RuntimeError("test")
|
raise RuntimeError("test")
|
||||||
|
@ -3799,6 +3799,118 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
|||||||
def test_unbacked_slice_with_step_cpp_wrapper(self):
|
def test_unbacked_slice_with_step_cpp_wrapper(self):
|
||||||
self.test_unbacked_slice_with_step()
|
self.test_unbacked_slice_with_step()
|
||||||
|
|
||||||
|
@fresh_cache()
|
||||||
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||||
|
def test_slice_with_tensor_indices(self):
|
||||||
|
# Test slicing with tensor start/stop/step on RHS (reading)
|
||||||
|
|
||||||
|
# Test 1: Basic slice with tensor start and stop
|
||||||
|
def f1(x, start_t, stop_t):
|
||||||
|
return x[start_t:stop_t]
|
||||||
|
|
||||||
|
x = torch.randn(20)
|
||||||
|
start_t = torch.tensor(5)
|
||||||
|
stop_t = torch.tensor(15)
|
||||||
|
fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
|
||||||
|
self.assertTrue(torch.allclose(fn1(x, start_t, stop_t), f1(x, start_t, stop_t)))
|
||||||
|
|
||||||
|
# Test 2: Slice with tensor step
|
||||||
|
def f2(x, start_t, stop_t, step_t):
|
||||||
|
return x[start_t:stop_t:step_t]
|
||||||
|
|
||||||
|
step_t = torch.tensor(2)
|
||||||
|
fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(
|
||||||
|
fn2(x, start_t, stop_t, step_t), f2(x, start_t, stop_t, step_t)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 3: Slice with only tensor start
|
||||||
|
def f3(x, start_t):
|
||||||
|
return x[start_t:]
|
||||||
|
|
||||||
|
fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
|
||||||
|
self.assertTrue(torch.allclose(fn3(x, start_t), f3(x, start_t)))
|
||||||
|
|
||||||
|
# Test 4: Slice with only tensor stop
|
||||||
|
def f4(x, stop_t):
|
||||||
|
return x[:stop_t]
|
||||||
|
|
||||||
|
fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
|
||||||
|
self.assertTrue(torch.allclose(fn4(x, stop_t), f4(x, stop_t)))
|
||||||
|
|
||||||
|
# Test 5: Negative indices with tensors
|
||||||
|
def f5(x, start_t):
|
||||||
|
return x[start_t:-1]
|
||||||
|
|
||||||
|
start_t_neg = torch.tensor(-10)
|
||||||
|
fn5 = torch.compile(f5, fullgraph=True, backend="inductor")
|
||||||
|
self.assertTrue(torch.allclose(fn5(x, start_t_neg), f5(x, start_t_neg)))
|
||||||
|
|
||||||
|
# Test 6: Multidimensional slice with tensor indices
|
||||||
|
def f6(x, start_t, stop_t):
|
||||||
|
return x[:, start_t:stop_t]
|
||||||
|
|
||||||
|
x_2d = torch.randn(10, 20)
|
||||||
|
fn6 = torch.compile(f6, fullgraph=True, backend="inductor")
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(fn6(x_2d, start_t, stop_t), f6(x_2d, start_t, stop_t))
|
||||||
|
)
|
||||||
|
|
||||||
|
@fresh_cache()
|
||||||
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||||
|
@torch._inductor.config.patch("cpp_wrapper", True)
|
||||||
|
def test_slice_with_tensor_indices_cpp_wrapper(self):
|
||||||
|
self.test_slice_with_tensor_indices()
|
||||||
|
|
||||||
|
@fresh_cache()
|
||||||
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||||
|
def test_select_with_tensor_index(self):
|
||||||
|
# Test direct tensor indexing (select) without calling .item()
|
||||||
|
|
||||||
|
# Test 1: Simple 0-d tensor as index
|
||||||
|
def f1(x, idx_tensor):
|
||||||
|
return x[idx_tensor]
|
||||||
|
|
||||||
|
x = torch.randn(10)
|
||||||
|
idx_tensor = torch.tensor(5)
|
||||||
|
fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
|
||||||
|
self.assertTrue(torch.allclose(fn1(x, idx_tensor), f1(x, idx_tensor)))
|
||||||
|
|
||||||
|
# Test 2: Negative tensor index
|
||||||
|
def f2(x, idx_tensor):
|
||||||
|
return x[idx_tensor]
|
||||||
|
|
||||||
|
idx_tensor_neg = torch.tensor(-2)
|
||||||
|
fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
|
||||||
|
self.assertTrue(torch.allclose(fn2(x, idx_tensor_neg), f2(x, idx_tensor_neg)))
|
||||||
|
|
||||||
|
# TODO support those less common patterns
|
||||||
|
# # Test 3: Multidimensional select with tensor index
|
||||||
|
# def f3(x, idx_tensor):
|
||||||
|
# return x[:, idx_tensor]
|
||||||
|
|
||||||
|
# x_2d = torch.randn(5, 10)
|
||||||
|
# fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
|
||||||
|
# self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))
|
||||||
|
|
||||||
|
# Test 4: Multiple tensor indices
|
||||||
|
# def f4(x, idx1, idx2):
|
||||||
|
# return x[idx1, idx2]
|
||||||
|
|
||||||
|
# x_2d = torch.randn(8, 12)
|
||||||
|
# idx1 = torch.tensor(3)
|
||||||
|
# idx2 = torch.tensor(7)
|
||||||
|
# fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
|
||||||
|
# self.assertTrue(torch.allclose(fn4(x_2d, idx1, idx2), f4(x_2d, idx1, idx2)))
|
||||||
|
|
||||||
|
@fresh_cache()
|
||||||
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||||
|
@torch._inductor.config.patch("cpp_wrapper", True)
|
||||||
|
def test_select_with_tensor_index_cpp_wrapper(self):
|
||||||
|
self.test_select_with_tensor_index()
|
||||||
|
|
||||||
@fresh_cache()
|
@fresh_cache()
|
||||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||||
def test_tensor_split(self):
|
def test_tensor_split(self):
|
||||||
|
@ -3138,7 +3138,7 @@ class InstructionTranslatorBase(
|
|||||||
|
|
||||||
def BUILD_SLICE(self, inst: Instruction) -> None:
|
def BUILD_SLICE(self, inst: Instruction) -> None:
|
||||||
items = self.popn(inst.argval)
|
items = self.popn(inst.argval)
|
||||||
self.push(SliceVariable(items))
|
self.push(SliceVariable(items, tx=self))
|
||||||
|
|
||||||
def BUILD_LIST(self, inst: Instruction) -> None:
|
def BUILD_LIST(self, inst: Instruction) -> None:
|
||||||
items = self.popn(inst.argval)
|
items = self.popn(inst.argval)
|
||||||
|
@ -1795,7 +1795,7 @@ class VariableBuilder:
|
|||||||
]
|
]
|
||||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
if isinstance(value, slice):
|
if isinstance(value, slice):
|
||||||
return SliceVariable(items, source=self.source)
|
return SliceVariable(items, self.tx, source=self.source)
|
||||||
else:
|
else:
|
||||||
return RangeVariable(items, source=self.source)
|
return RangeVariable(items, source=self.source)
|
||||||
|
|
||||||
|
@ -1235,6 +1235,20 @@ class BuiltinVariable(VariableTracker):
|
|||||||
# scenario is to de-sugar eagerly.
|
# scenario is to de-sugar eagerly.
|
||||||
fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]
|
fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]
|
||||||
|
|
||||||
|
# Convert size-1 TensorVariable indices to SymIntVariable by calling .item()
|
||||||
|
# This decomposes tensor[t] to u=t.item(); tensor[u] at the dynamo level
|
||||||
|
if (
|
||||||
|
fn is operator.getitem
|
||||||
|
and len(args) == 2
|
||||||
|
and isinstance(args[1], variables.TensorVariable)
|
||||||
|
):
|
||||||
|
tensor_idx = args[1]
|
||||||
|
# Only convert if we know it's size-1 (not for advanced indexing)
|
||||||
|
if tensor_idx.size is not None and all(s == 1 for s in tensor_idx.size):
|
||||||
|
args = list(args)
|
||||||
|
args[1] = tensor_idx.call_method(tx, "item", [], {})
|
||||||
|
args = tuple(args)
|
||||||
|
|
||||||
if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
|
if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
|
||||||
# Standard indexing will force specialization due to
|
# Standard indexing will force specialization due to
|
||||||
# __index__. Rewrite as a regular torch op which will
|
# __index__. Rewrite as a regular torch op which will
|
||||||
@ -1745,7 +1759,7 @@ class BuiltinVariable(VariableTracker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def call_slice(self, tx: "InstructionTranslator", *args):
|
def call_slice(self, tx: "InstructionTranslator", *args):
|
||||||
return variables.SliceVariable(args)
|
return variables.SliceVariable(args, tx)
|
||||||
|
|
||||||
def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
|
def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||||
from .builder import wrap_fx_proxy
|
from .builder import wrap_fx_proxy
|
||||||
|
@ -1386,7 +1386,7 @@ class NamedTupleVariable(TupleVariable):
|
|||||||
|
|
||||||
|
|
||||||
class SliceVariable(VariableTracker):
|
class SliceVariable(VariableTracker):
|
||||||
def __init__(self, items, **kwargs) -> None:
|
def __init__(self, items, tx=None, **kwargs) -> None:
|
||||||
items_to_map = items
|
items_to_map = items
|
||||||
start, stop, step = [variables.ConstantVariable.create(None)] * 3
|
start, stop, step = [variables.ConstantVariable.create(None)] * 3
|
||||||
|
|
||||||
@ -1399,18 +1399,27 @@ class SliceVariable(VariableTracker):
|
|||||||
else:
|
else:
|
||||||
raise AssertionError
|
raise AssertionError
|
||||||
|
|
||||||
if isinstance(start, variables.TensorVariable) or isinstance(
|
# Convert TensorVariable to SymIntVariable by calling .item()
|
||||||
stop, variables.TensorVariable
|
# This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level
|
||||||
):
|
if isinstance(start, variables.TensorVariable):
|
||||||
unimplemented_v2(
|
assert tx is not None, (
|
||||||
gb_type="Dynamic slicing with Tensor arguments",
|
"tx is required when slice indices are TensorVariables"
|
||||||
context=f"SliceVariable start: {start}, stop: {stop}, step: {step}",
|
|
||||||
explanation="Creating slices with Tensor arguments is not supported. "
|
|
||||||
"e.g. `l[:x]`, where `x` is a 1-element tensor.",
|
|
||||||
hints=[
|
|
||||||
*graph_break_hints.SUPPORTABLE,
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
assert start.size is None or all(s == 1 for s in start.size)
|
||||||
|
start = start.call_method(tx, "item", [], {})
|
||||||
|
if isinstance(stop, variables.TensorVariable):
|
||||||
|
assert tx is not None, (
|
||||||
|
"tx is required when slice indices are TensorVariables"
|
||||||
|
)
|
||||||
|
assert stop.size is None or all(s == 1 for s in stop.size)
|
||||||
|
stop = stop.call_method(tx, "item", [], {})
|
||||||
|
if isinstance(step, variables.TensorVariable):
|
||||||
|
assert tx is not None, (
|
||||||
|
"tx is required when slice indices are TensorVariables"
|
||||||
|
)
|
||||||
|
assert step.size is None or all(s == 1 for s in step.size)
|
||||||
|
step = step.call_method(tx, "item", [], {})
|
||||||
|
|
||||||
self.items = (start, stop, step)
|
self.items = (start, stop, step)
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
Reference in New Issue
Block a user