Implementation of scan (#134102)

This operation is intended to be the counterpart to `associative_scan`, but it can also operate with non-associative combine functions.
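A minimal usage sketch, mirroring the docstring example added in this PR (the `add` combine function is illustrative, not authoritative):

import torch
from torch._higher_order_ops.scan import scan

def add(x: torch.Tensor, y: torch.Tensor):
    next_carry = x + y  # carry passed to the next step
    return next_carry, next_carry  # (next carry, per-step output)

init = torch.zeros(1)
xs = torch.arange(1, 5, dtype=torch.float32)
# last_carry == tensor([10.]), cumsum == tensor([1., 3., 6., 10.])
last_carry, cumsum = scan(add, init, xs, dim=0)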

@ydwu4

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134102
Approved by: https://github.com/ydwu4
Thomas Bohnstingl
2024-09-10 04:51:16 +00:00
committed by PyTorch MergeBot
parent 6546c6186d
commit e889252493
8 changed files with 2376 additions and 231 deletions

File diff suppressed because it is too large

View File

@ -23,6 +23,7 @@ from torch import distributed as dist
from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor
from torch._dynamo.testing import make_test_cls_with_patches, rand_strided
from torch._guards import tracing, TracingContext
from torch._higher_order_ops.scan import scan
from torch._subclasses.fake_tensor import (
DynamicOutputShapeException,
extract_tensor_metadata,
@ -923,6 +924,21 @@ class FakeTensorTest(TestCase):
self.assertIsInstance(r, FakeTensor)
self.assertEqual(r.size(), [3])
@parametrize("reverse", [False, True])
def test_scan(self, reverse):
def add(x, y):
return x + y, x + y
with torch._subclasses.fake_tensor.FakeTensorMode():
x = torch.randn((3, 5, 7), device="cpu")
init = torch.randn((3, 1, 7), device="cpu")
r = scan(add, init, x, dim=1, reverse=reverse)
self.assertIsInstance(r[0], FakeTensor)
self.assertIsInstance(r[1], FakeTensor)
self.assertEqual(r[0].size(), init.size())
self.assertEqual(r[1].size(), x.size())
instantiate_parametrized_tests(FakeTensorTest)

View File

@ -3188,6 +3188,7 @@ LEGACY_MOD_INLINELIST = {
"torch._higher_order_ops.cond",
"torch._higher_order_ops.while_loop",
"torch._higher_order_ops.associative_scan",
"torch._higher_order_ops.scan",
"torch.nn.attention.flex_attention",
"torch.ao.quantization.pt2e.export_utils",
"torch.ao.quantization.pt2e.qat_utils",
@ -3228,6 +3229,7 @@ MOD_INLINELIST = [
"torch._functorch.functional_call",
"torch._functorch.vmap",
"torch._higher_order_ops.associative_scan",
"torch._higher_order_ops.scan",
"torch._higher_order_ops.strict_mode",
"torch._higher_order_ops.while_loop",
"torch._inductor.test_operators",

View File

@ -602,6 +602,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
return RunWithRNGStateHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "associative_scan":
return AssociativeScanHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "scan":
return ScanHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "call_torchbind":
return CallTorchbindHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "wrap_with_set_grad_enabled":
@ -1022,16 +1024,16 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
def arg_extractor(combine_fn, input, dim):
return combine_fn, input, dim
def arg_extractor(combine_fn, xs, dim):
return combine_fn, xs, dim
combine_fn, input, dim = arg_extractor(*args, **kwargs)
combine_fn, xs, dim = arg_extractor(*args, **kwargs)
if input.python_type() != list:
if xs.python_type() != list:
unimplemented(
f"Expected input to be a list of tensors but got {input.python_type()}",
f"Expected xs to be a list of tensors but got {xs.python_type()}",
)
assert isinstance(input, torch._dynamo.variables.lists.BaseListVariable)
assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable)
# Trace the subgraph
# TODO: Fix these pointless new_empty calls appearing in the dynamo output graph.
@ -1054,7 +1056,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
"requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad),
},
)
for leaf in itertools.chain(input.items, input.items)
for leaf in itertools.chain(xs.items, xs.items)
]
(
(combine_result, combine_treespec),
@ -1065,7 +1067,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
combine_fn,
sub_args,
sub_kwargs={},
description="scan_combine",
description="associative_scan_combine_fn",
source_target=self.value,
set_subgraph_inputs="flatten_manual",
)
@ -1080,9 +1082,9 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"Expected combine_fn to return a list if tensor but got {combine_result.python_type()}",
)
input_proxy = input.as_proxy()
xs_proxy = xs.as_proxy()
combine_result_proxy = combine_result.as_proxy()
for result, inp_proxy in zip(combine_result_proxy, input_proxy):
for result, inp_proxy in zip(combine_result_proxy, xs_proxy):
inp_meta = inp_proxy.node.meta["example_value"]
combine_result_meta = result.node.meta["example_value"]
if combine_result_meta.device != inp_meta.device:
@ -1097,18 +1099,17 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
combine_fn_name = add_subgraph(tx, "scan_combine", combine_gm)
combine_fn_name = add_subgraph(tx, "associative_scan_combine_fn", combine_gm)
p_args = (
make_attr(tx, combine_fn_name),
input_proxy,
xs_proxy,
dim.as_proxy(),
)
with tx.fake_mode:
out_meta = tuple(
inp_proxy.node.meta["example_value"].clone()
for inp_proxy in input_proxy
inp_proxy.node.meta["example_value"].clone() for inp_proxy in xs_proxy
)
return wrap_fx_proxy(
tx=tx,
@ -1119,6 +1120,196 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
@raise_hard_error_if_graph_break(
reason="scan must be captured completely with torch.compile."
)
def call_function(
self,
tx: "InstructionTranslator",
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
) -> VariableTracker:
from torch._higher_order_ops.scan import make_expanded_output_shape
from .builder import SourcelessBuilder, wrap_fx_proxy
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
def arg_extractor(combine_fn, init, xs, dim, reverse):
return combine_fn, init, xs, dim, reverse
combine_fn, init, xs, dim, reverse = arg_extractor(*args, **kwargs)
if xs.python_type() != list:
unimplemented(
f"Expected xs to be a list of tensors but got {xs.python_type()}",
)
assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable)
if init.python_type() != list:
unimplemented(
f"Expected init to be a list of tensors but got {init.python_type()}",
)
assert isinstance(init, torch._dynamo.variables.lists.BaseListVariable)
dim_fake = (
dim.as_proxy()
if type(dim.as_proxy()) == int
else get_fake_value(dim.as_proxy().node, tx)
)
scan_length = get_fake_value(xs.items[0].as_proxy().node, tx).size()[dim_fake]
if scan_length == 0:
unimplemented(
"scan() operator doesn't support zero-sized tensors during tracing."
)
init_len = len(init.items)
if init_len == 0:
unimplemented("scan() operator requires init leaves.")
# Trace the subgraph
# TODO: Fix these pointless new_empty calls appearing in the dynamo output graph.
# TODO: Unify handling of sub_args across control flow ops, such as cond, while_loop, etc.
sub_args_init = [
ini.call_method(
tx,
"new_empty",
args=(
SourcelessBuilder.create(
tx,
ini.size
if ini.size is not None
else tuple(
BuiltinVariable(getattr)
.call_function(
tx, [ini, ConstantVariable.create("shape")], {}
)
.items
),
),
),
kwargs={
"dtype": SourcelessBuilder.create(tx, ini.dtype),
"device": SourcelessBuilder.create(tx, ini.device),
"requires_grad": SourcelessBuilder.create(tx, ini.requires_grad),
},
)
for ini in init.items
]
sub_args_inp_shapes = make_expanded_output_shape(
dim_fake,
1,
[
tuple(
BuiltinVariable(getattr)
.call_function(tx, [inp, ConstantVariable.create("shape")], {})
.items
)
for inp in xs.items
],
True,
)
sub_args_inp = [
inp.call_method(
tx,
"new_empty",
args=(SourcelessBuilder.create(tx, inp_sh),),
kwargs={
"dtype": SourcelessBuilder.create(tx, inp.dtype),
"device": SourcelessBuilder.create(tx, inp.device),
"requires_grad": SourcelessBuilder.create(tx, inp.requires_grad),
},
)
for inp, inp_sh in zip(xs.items, sub_args_inp_shapes)
]
sub_args = sub_args_init + sub_args_inp
(
(combine_result, combine_treespec),
combine_graph,
combine_lifted_freevars,
) = speculate_subgraph(
tx,
combine_fn,
sub_args,
sub_kwargs={},
description="scan_combine_fn",
source_target=self.value,
set_subgraph_inputs="flatten_manual",
)
if combine_lifted_freevars:
unimplemented(
f"Combine fn had unexpected freevars: {combine_lifted_freevars}"
)
if any(cr.python_type() != list for cr in combine_result.items):
unimplemented(
f"Expected combine_fn to return a list if tensor but got {combine_result.python_type()}",
)
xs_proxy = xs.as_proxy()
init_proxy = init.as_proxy()
combine_carry_proxy = combine_result.items[0].as_proxy()
# Checks for carry and init
for ini_proxy, carry in zip(init_proxy, combine_carry_proxy):
ini_meta = ini_proxy.node.meta["example_value"]
carry_meta = carry.node.meta["example_value"]
if (
carry_meta.device != ini_meta.device
or carry_meta.dtype != ini_meta.dtype
or carry_meta.shape != ini_meta.shape
):
unimplemented(
f"Expected metadata of the combine_fn result {carry_meta} to be the same as "
+ f"the metadata of init with {ini_meta}"
)
combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
combine_fn_name = add_subgraph(tx, "scan_combine_fn", combine_gm)
p_args = (
make_attr(tx, combine_fn_name),
init_proxy,
xs_proxy,
dim.as_proxy(),
reverse.as_proxy(),
)
with tx.fake_mode:
# For the fake mode, we need to duplicate the init tensor along the dim
# to have the same size as the xs arguments
# We also do a clone with contiguous_format. This is to be consistent with
# eager semantic of map, which stacks the outputs. The result is contiguous
# as a result of the stack operation.
fake_out_shapes = make_expanded_output_shape(
dim_fake,
scan_length,
[
get_fake_value(o.as_proxy().node, tx).size()
for o in combine_result.items[1].items
],
)
out_meta = (
[init_p.node.meta["example_value"].clone() for init_p in init_proxy],
list( # noqa: C400
t.as_proxy()
.node.meta["example_value"]
.expand(*sh)
.clone(memory_format=torch.contiguous_format)
for t, sh in zip(combine_result.items[1].items, fake_out_shapes)
),
)
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function", torch.ops.higher_order.scan, p_args, {}
),
example_value=out_meta,
)
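# NOTE (editorial sketch, not part of the original diff): a successful capture installs the
# speculated combine graph as a submodule and emits a single higher-order call in the
# Dynamo output graph, roughly of the form below (generated names such as
# scan_combine_fn_0, l_init_ and l_xs_ are illustrative):
#
#     scan_combine_fn_0 = self.scan_combine_fn_0
#     carries, ys = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 1, False)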
def non_single_tensor_return_unsupported(api, ret):
from . import TensorVariable

View File

@ -73,8 +73,8 @@ class AssociativeScanOp(HigherOrderOperator):
def __init__(self):
super().__init__("associative_scan")
def __call__(self, combine_fn, input, dim):
return super().__call__(combine_fn, input, dim)
def __call__(self, combine_fn, xs, dim):
return super().__call__(combine_fn, xs, dim)
associative_scan_op = AssociativeScanOp()
@ -82,13 +82,13 @@ associative_scan_op = AssociativeScanOp()
def associative_scan(
combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree],
input: pytree.PyTree,
xs: pytree.PyTree,
dim: int,
reverse: bool = False,
combine_mode: str = "pointwise",
) -> torch.Tensor:
r"""
Performs an inclusive scan with an associative pointwise combine function.
Performs an inclusive scan with an associative combine function.
.. warning::
`torch.associative_scan` is a prototype feature in PyTorch. It currently
@ -102,14 +102,15 @@ def associative_scan(
Args:
combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``,
or if input is a pytree ``(pytree, pytree) -> pytree``.
This function must be pure, pointwise, and satisfy the associative property.
input (torch.Tensor): The input tensor, or nested pytree of tensors.
This function must be pure, i.e., no lifted arguments are supported at the moment;
it must satisfy the associative property and have no side effects.
xs (torch.Tensor): The input tensor, or nested pytree of tensors.
All inputs are expected to have the same shape.
dim (int): the dimension to scan over
reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension.
combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``.
reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``.
combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``, default ``pointwise``.
If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations
and ``input`` must be CUDA tensors.
and ``xs`` must be CUDA tensors.
In all other cases ``combine_mode=generic`` should be used.
Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``.
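A hedged usage sketch of the documented ``combine_mode`` behavior (the multiplication combine function and CPU tensors are illustrative; ``combine_mode="pointwise"`` would require CUDA inputs):

import torch
from torch._higher_order_ops.associative_scan import associative_scan

def mul(x, y):
    return x * y

x = torch.randn(4, 8)
# "generic" works for CPU tensors and non-pointwise combine functions.
cumprod = associative_scan(mul, x, 1, combine_mode="generic")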
@ -122,27 +123,32 @@ def associative_scan(
cumsum = associative_scan(add, x, dim)
"""
assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}"
assert isinstance(dim, int), "dim must be an int, but got {type(dim)}"
assert combine_mode in ["pointwise", "generic"]
if not callable(combine_fn):
raise RuntimeError("Combine_fn must be a callable, but got {combine_fn}")
if not isinstance(dim, int):
raise RuntimeError("Dim must be an int, but got " + str(type(dim)))
if combine_mode not in ["pointwise", "generic"]:
raise RuntimeError(
"Combine_mode must either 'pointwise' or 'generic', but got {combine_mode}"
)
if not torch._dynamo.is_compiling():
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(associative_scan, fullgraph=True)(
combine_fn, input, dim, reverse=reverse, combine_mode=combine_mode
combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode
)
leaves, spec = pytree.tree_flatten(input)
leaves, spec = pytree.tree_flatten(xs)
if combine_mode == "pointwise" and not all(l.device.type == "cuda" for l in leaves):
raise ValueError(
"For combine_mode='pointwise', all input tensors need to be on CUDA"
)
assert len(leaves) >= 1, "expected at least 1 input leaf"
assert all(
isinstance(x, torch.Tensor) for x in leaves
), "input leaves must be a Tensor"
if len(leaves) == 0:
raise RuntimeError("Expected at least 1 xs leaf")
if any(not isinstance(x, torch.Tensor) for x in leaves):
raise RuntimeError("xs leaves must be a Tensor")
if reverse:
leaves = [torch.flip(elem, [dim]) for elem in leaves]
@ -152,20 +158,21 @@ def associative_scan(
dim = utils.canonicalize_dim(ndim, dim)
for x in leaves[1:]:
assert x.shape == shape, "All input tensors must have the same shape"
assert x.shape == shape, "All xs tensors must have the same shape"
out = combine_fn(
pytree.tree_unflatten(leaves, spec),
pytree.tree_unflatten(leaves, spec),
)
out_leaves, tree_out = pytree.tree_flatten(out)
assert len(leaves) == len(
out_leaves
), "The pytree of the output of the operator needs to match the input pytree"
for x in out_leaves:
assert (
x.shape == shape
), "The pytree of the output of the operator needs to match the input pytree"
if len(leaves) != len(out_leaves):
raise RuntimeError(
"The number of leaves of the pytree of the output of the operator needs to match the length of the pytree of the input"
)
if any(x.shape != shape for x in out_leaves):
raise RuntimeError(
"The pytree of the output of the operator needs to match the xs pytree"
)
combine_fn = functools.partial(
wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves)
@ -194,7 +201,7 @@ def generic_associative_scan(operator, elems_flat, dim=0):
or if input is a pytree ``(pytree, pytree) -> pytree``.
This function must be pure, pointwise, and satisfy the associative property.
elems_flat (torch.Tensor): A list of torch.Tensors converted from the pytree of
``input`` provided to ``associative_scan``.
``xs`` provided to ``associative_scan``.
All inputs are expected to have the same shape.
dim (int): the dimension to scan over
@ -279,19 +286,19 @@ def generic_associative_scan(operator, elems_flat, dim=0):
def trace_associative_scan(
proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int
proxy_mode, func_overload, combine_fn: Callable, xs: List[torch.Tensor], dim: int
):
with disable_proxy_modes_tracing():
sample_inputs = [
sample_xs = [
torch.empty_like(
x,
dtype=x.dtype,
device=x.device,
requires_grad=x.requires_grad,
)
for x in itertools.chain(input, input)
for x in itertools.chain(xs, xs)
]
combine_graph = reenter_make_fx(combine_fn)(*sample_inputs)
combine_graph = reenter_make_fx(combine_fn)(*sample_xs)
outputs = None
for node in combine_graph.graph.nodes:
@ -307,10 +314,10 @@ def trace_associative_scan(
assert outputs is not None
assert len(outputs) == len(
input
), f"expected combine_fn to return {len(input)} results but got {len(outputs)}"
xs
), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}"
for i, o in zip(input, outputs):
for i, o in zip(xs, outputs):
o_meta = o.meta["tensor_meta"]
assert o_meta.dtype == i.dtype, (
f"combine_fn output type mismatch, expected {i.dtype} "
@ -321,20 +328,20 @@ def trace_associative_scan(
proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph)
args = (combine_graph, input, dim)
args = (combine_graph, xs, dim)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="associative_scan"
)
with disable_proxy_modes_tracing():
out = [aten.clone(x) for x in input]
out = [aten.clone(x) for x in xs]
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def associative_scan_op_dense(combine_fn, input, dim):
def associative_scan_op_dense(combine_fn, xs, dim):
raise NotImplementedError("associative_scan is not implemented for eager")
@ -344,22 +351,22 @@ associative_scan_op.py_impl(DispatchKey.Autograd)(
@associative_scan_op.py_impl(ProxyTorchDispatchMode)
def associative_scan_proxy_mode(mode, combine_fn, input, dim):
return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim)
def associative_scan_proxy_mode(mode, combine_fn, xs, dim):
return trace_associative_scan(mode, associative_scan_op, combine_fn, xs, dim)
@associative_scan_op.py_impl(FakeTensorMode)
def assoiciative_scan_fake_tensor_mode(mode, combine_fn, input, dim):
def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs, dim):
with mode:
return [x.clone() for x in input]
return [x.clone() for x in xs]
@associative_scan_op.py_functionalize_impl
def associative_scan_functionalize(ctx, combine_fn, input, dim):
unwrapped_input = ctx.unwrap_tensors(input)
def associative_scan_functionalize(ctx, combine_fn, xs, dim):
unwrapped_xs = ctx.unwrap_tensors(xs)
with ctx.redispatch_to_next() as m:
functional_combine_fn = ctx.functionalize(
_maybe_run_with_interpreter(combine_fn)
)
ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim)
ret = associative_scan_op(functional_combine_fn, unwrapped_xs, dim)
return ctx.wrap_tensors(ret)

View File

@ -0,0 +1,438 @@
# mypy: allow-untyped-defs
import functools
import itertools
from typing import Callable, List, Tuple
import torch
import torch._prims_common as utils
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_set_compilation_env,
autograd_not_implemented,
reenter_make_fx,
unique_graph_id,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.utils._python_dispatch import _get_current_dispatch_mode
aten = torch._ops.ops.aten
def wrap_combine_fn_flat(
*args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves
):
assert len(args) == (num_init_leaves + num_inp_leaves)
carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init)
xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs)
carry, combined = combine_fn(carry, xs)
carry_flat = pytree.tree_leaves(carry)
combined_flat = pytree.tree_leaves(combined)
assert num_init_leaves == len(carry_flat)
return (carry_flat, combined_flat)
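# NOTE (editorial sketch, not part of the original diff): the flattening contract above can
# be exercised directly; the dict init, the single-leaf xs slice, and the helper name below
# are illustrative. All referenced modules are imported at the top of this file.
def _example_wrap_combine_fn_flat():
    def combine(carry, x):
        h = carry["h"] + x
        return {"h": h}, h

    init = {"h": torch.zeros(3)}
    x_slice = torch.ones(3)
    leaves_init, spec_init = pytree.tree_flatten(init)
    leaves_xs, spec_xs = pytree.tree_flatten(x_slice)
    flat_fn = functools.partial(
        wrap_combine_fn_flat,
        combine_fn=combine,
        spec_init=spec_init,
        spec_xs=spec_xs,
        num_init_leaves=len(leaves_init),
        num_inp_leaves=len(leaves_xs),
    )
    # Flat leaves in -> (flat carry leaves, flat per-step output leaves) out.
    carry_flat, out_flat = flat_fn(*leaves_init, *leaves_xs)
    return carry_flat, out_flat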
def scan(
combine_fn: Callable[
[pytree.PyTree, pytree.PyTree], Tuple[pytree.PyTree, pytree.PyTree]
],
init: pytree.PyTree,
xs: pytree.PyTree,
/,
*,
dim: int = 0,
reverse: bool = False,
) -> Tuple[pytree.PyTree, pytree.PyTree]:
r"""
Performs an inclusive scan with a combine function.
.. warning::
`torch.scan` is a prototype feature in PyTorch. It currently
does not support autograd and you may run into miscompiles.
Read more about feature classification at:
https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
Args:
combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> (Tensor, Tensor)``,
or if xs is a pytree ``(pytree, pytree) -> (pytree, pytree)``.
The first input to ``combine_fn`` is the previous or initial scan carry
and the second input element to ``combine_fn`` is a slice of the input along dim.
The first output element of ``combine_fn`` is the next scan carry
and the second output of ``combine_fn`` represents a slice of the output.
This function must be pure, i.e., no lifted arguments are supported at the moment,
and it may not have any side effects.
init (torch.Tensor or pytree with tensor leaves): The initial scan carry, a tensor, or nested pytree of tensors.
The ``init`` is expected to have the same pytree structure as the first output element (i.e. carry)
of ``combine_fn``.
xs (torch.Tensor or pytree with tensor leaves): The input tensor, or nested pytree of tensors.
Kwargs:
dim (int): the dimension to scan over, default 0.
reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``.
Returns:
final_carry (torch.Tensor or pytree with tensor leaves),
the final carry of the scan operation with same pytree structure as init.
out (torch.Tensor or pytree with tensor leaves),
each tensor leaf is a stacked output along dim, where each slice is the output of a scan iteration.
Example::
def add(x: torch.Tensor, y: torch.Tensor):
next_carry = y = x + y
return next_carry, y
i0 = torch.zeros(1)
xs = torch.arange(1, 5)
# returns torch.tensor([10]), torch.tensor([1., 3., 6., 10.])
last_carry, cumsum = scan(add, init=i0, xs=xs)
"""
if not callable(combine_fn):
raise RuntimeError("Combine_fn must be a callable, but got {combine_fn}")
if not isinstance(dim, int):
raise RuntimeError("Dim must be an int, but got " + str(type(dim)))
if not isinstance(reverse, bool):
raise RuntimeError("Reverse must be a bool, but got " + str(type(reverse)))
# TODO: Support closures/nn_modules in order to be able to represent RNNs with scan
# TODO: Support _inductor lowering
# TODO: Support Autograd
# TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc.
# Dynamo is expecting a callable with "__code__" attribute.
# We cannot directly pass scan_op to it, so we wrap it in a dummy function.
def _scan_op_wrapper(*args, **kwargs):
return scan(*args, **kwargs)
if not torch._dynamo.is_compiling():
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(_scan_op_wrapper, backend="eager", fullgraph=True)(
combine_fn, init, xs, dim=dim, reverse=reverse
)
leaves_init, spec_init = pytree.tree_flatten(init)
leaves_xs, spec_xs = pytree.tree_flatten(xs)
if len(leaves_init) == 0:
raise RuntimeError("Init tensors must be provided")
if any(not isinstance(x, torch.Tensor) for x in leaves_init):
raise RuntimeError("All init leaves must be a Tensor")
if any(not isinstance(x, torch.Tensor) for x in leaves_xs):
raise RuntimeError("All xs leaves must be a Tensor")
if any(x.shape[dim] == 0 for x in leaves_xs):
raise RuntimeError("All xs leaves must have a scan dimension > 0")
if len(leaves_xs) > 0:
shape = leaves_xs[0].shape
ndim = len(shape)
dim = utils.canonicalize_dim(ndim, dim)
out = combine_fn(
pytree.tree_unflatten(leaves_init, spec_init),
pytree.tree_unflatten(
[aten.slice(elem, dim, 0, 1, 1) for elem in leaves_xs], spec_xs
),
)
# The first output needs to have the same pytree as init
carry_leaves = pytree.tree_leaves(out[0])
if len(carry_leaves) != len(leaves_init):
raise RuntimeError(
"The number of leaves of the pytree of the new carry produced by the operator\
needs to match the length of the pytree of the init"
)
if any(
in_l.shape != out_l.shape for in_l, out_l in zip(leaves_init, carry_leaves)
):
raise RuntimeError(
"The pytree of the new carry produced by the operator needs to match the pytree of the init"
)
# There are no pytree restrictions on the second output of the operator
out_leaves, tree_out = pytree.tree_flatten(out[1])
combine_fn = functools.partial(
wrap_combine_fn_flat,
combine_fn=combine_fn,
spec_init=spec_init,
spec_xs=spec_xs,
num_init_leaves=len(leaves_init),
num_inp_leaves=len(leaves_xs),
)
result_carry, result_flat = scan_op(
combine_fn, leaves_init, leaves_xs, dim, reverse
)
return pytree.tree_unflatten(result_carry, spec_init), pytree.tree_unflatten(
result_flat, tree_out
)
else:
return pytree.tree_unflatten(leaves_init, spec_init), xs
class ScanOp(HigherOrderOperator):
def __init__(self):
super().__init__("scan")
def __call__(self, combine_fn, init, xs, dim, reverse):
return super().__call__(combine_fn, init, xs, dim, reverse)
scan_op = ScanOp()
def generic_scan(operator, init, xs, dim=0, reverse=False):
def _scan(init, xs):
"""Perform scan on `elems` using `elems_init."""
carry = init
if len(xs) == 0:
return carry, []
num_elems = xs[0].shape[dim]
if reverse:
ind = num_elems - 1
else:
ind = 0
# Compute dummy shapes for the pre-allocation
dummy_carry, dummy_out = operator(
*carry, *[aten.slice(elem, dim, 0, 1, 1) for elem in xs]
)
output_scanned_dim = dummy_out[0].shape[dim]
# Pre-allocate
# outs -> Output matrix
# idxs -> Index matrix for scatter_
outs, outs_idxs = zip(
*[
[
torch.zeros(
list(e.size())[:dim]
+ [list(e.size())[dim] * num_elems]
+ list(e.size())[dim + 1 :],
dtype=e.dtype,
device=e.device,
),
torch.cat(
[
id * t
for id, t in zip(
range(output_scanned_dim),
torch.tensor_split(
torch.ones_like(e, dtype=torch.int64),
output_scanned_dim,
dim=dim,
),
)
],
dim,
),
]
for i, e in enumerate(dummy_out)
]
)
def store_in_mat(mat, out, d, index, index_modifier):
# Store the intermediate out in the outs matrix
for o, x, idx in zip(mat, out, index):
o.scatter_(d, idx + index_modifier, x)
def cond(i, n, r):
if (r and i < 0) or (not r and i > (n - 1)):
return False
else:
return True
def op(i):
if reverse:
return i - 1
else:
return i + 1
while cond(ind, num_elems, reverse):
carry, out = operator(
*carry,
*[aten.slice(elem, dim, ind, ind + 1, 1) for elem in xs],
)
# Store the inits in the outs matrix.
store_in_mat(outs, out, dim, outs_idxs, ind * output_scanned_dim)
ind = op(ind)
return (carry, list(outs))
scans = _scan(init, xs)
return scans
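# NOTE (editorial sketch, not part of the original diff): for a single carry tensor, a single
# input tensor, and a simplified operator(carry, x_slice) -> (next_carry, y_slice) that keeps
# the scanned dim with size 1 (as the aten.slice calls above do), the pre-allocation and
# scatter machinery is equivalent to this minimal reference loop:
def _scan_reference_single(operator, init, x, dim=0, reverse=False):
    carry = init
    outs = [None] * x.shape[dim]
    idxs = range(x.shape[dim])
    for i in (reversed(idxs) if reverse else idxs):
        # Output slice i always corresponds to input slice i, also for reverse scans.
        carry, y = operator(carry, x.narrow(dim, i, 1))
        outs[i] = y
    return carry, torch.cat(outs, dim)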
def make_expanded_output_shape(dim, scan_length, shapes, use_sh=False):
expanded_shapes = [
tuple(
(s if use_sh else -1) if i != dim else scan_length for i, s in enumerate(sh)
)
for sh in shapes
]
return expanded_shapes
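# For example (editorial note): with dim=1, scan_length=5 and shapes=[(3, 1, 7)] this
# returns [(-1, 5, -1)], suitable for Tensor.expand; with use_sh=True it returns the
# concrete shape [(3, 5, 7)] instead.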
def trace_scan(
proxy_mode,
func_overload,
combine_fn: Callable,
init: List[torch.Tensor],
xs: List[torch.Tensor],
dim: int,
reverse: bool,
):
with disable_proxy_modes_tracing():
sample_inits = [
torch.empty_like(
x_init,
dtype=x_init.dtype,
device=x_init.device,
requires_grad=x_init.requires_grad,
)
for x_init in init
]
sample_xs = [
torch.empty_like(
aten.slice(x, dim, 0, 1, 1),
dtype=x.dtype,
device=x.device,
requires_grad=x.requires_grad,
)
for x in xs
]
combine_graph = reenter_make_fx(combine_fn)(*sample_inits, *sample_xs)
outputs = None
for node in combine_graph.graph.nodes:
if node.op == "output":
assert outputs is None
assert len(node.args) == 1
outputs = node.args[0]
assert outputs is not None
if len(outputs) != 2:
raise RuntimeError(
f"Expected to return 2 outputs: carry, out_matrix, but got:"
f"\n {len(outputs)} elements"
)
for ini, carry in zip(init, outputs[0]):
ini_meta = ini
carry_meta = carry.meta["tensor_meta"]
carry_val = carry.meta["val"]
if (
carry_val.device != ini_meta.device
or carry_meta.dtype != ini_meta.dtype
or carry_meta.shape != ini_meta.shape
):
raise RuntimeError(
f"Expected metadata of the combine_fn result {carry_meta} to be the same as "
+ f"the metadata of init with {ini_meta}"
)
_, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")
proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph)
args = (combine_graph, init, xs, dim, reverse)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="scan"
)
with disable_proxy_modes_tracing():
scan_length = xs[0].shape[dim]
fake_out_shapes = make_expanded_output_shape(
dim, scan_length, [o.meta["val"].size() for o in outputs[1]]
)
def expand_tensor(t, sh):
if isinstance(t, torch.Tensor):
return t.expand(*sh)
return t
expanded_outs = [
pytree.tree_map(expand_tensor, t.meta["val"], sh)
for t, sh in zip(outputs[1], fake_out_shapes)
]
out = (init, expanded_outs)
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
@scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def scan_op_dense(combine_fn, init, xs, dim, reverse):
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return generic_scan(combine_fn, init, xs, dim, reverse)
scan_op.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(scan_op, deferred_error=True)
)
@scan_op.py_impl(ProxyTorchDispatchMode)
def scan_proxy_mode(mode, combine_fn, init, xs, dim, reverse):
return trace_scan(mode, scan_op, combine_fn, init, xs, dim, reverse)
@scan_op.py_impl(FakeTensorMode)
def scan_fake_tensor_mode(mode, combine_fn, init, xs, dim, reverse):
with mode:
dim_len = xs[0].shape[dim]
carry, outputs = combine_fn(
*init, *[aten.slice(inp, dim, 0, 1, 1) for inp in xs]
)
fake_out_shapes = [
tuple(-1 if i != dim else dim_len for i, sh in enumerate(o.size()))
for o in outputs
]
out = (
carry,
tuple(t.expand(*sh).clone() for t, sh in zip(outputs, fake_out_shapes)),
)
return out
@scan_op.py_functionalize_impl
def scan_functionalize(ctx, combine_fn, init, xs, dim, reverse):
unwrapped_xs = ctx.unwrap_tensors(xs)
unwrapped_init = ctx.unwrap_tensors(init)
with ctx.redispatch_to_next() as m:
functional_combine_fn = ctx.functionalize(combine_fn)
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
sample_xs = list(itertools.chain(unwrapped_init, unwrapped_init))
if _has_potential_branch_input_mutation(
functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch
):
raise UnsupportedAliasMutationException(
"Combine_fn might be modifying the input!"
)
if _has_potential_branch_input_alias(
functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch
):
raise UnsupportedAliasMutationException(
"Combine_fn might be aliasing the input!"
)
ret = scan_op(functional_combine_fn, unwrapped_init, unwrapped_xs, dim, reverse)
return ctx.wrap_tensors(ret)
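To make the pytree contract of ``scan`` concrete, a hedged end-to-end sketch (the running-statistics combine function, shapes, and names are illustrative; autograd is not supported by this prototype):

import torch
from torch._higher_order_ops.scan import scan

def running_stats(carry, x):
    # carry: tuple (running_sum, running_max); x: slice of xs along dim (size 1 kept)
    s, m = carry
    s = s + x
    m = torch.maximum(m, x)
    return (s, m), s  # next carry must match init's pytree structure and shapes

xs = torch.randn(10, 4)
init = (torch.zeros(1, 4), torch.full((1, 4), float("-inf")))
(total, running_max), partial_sums = scan(running_stats, init, xs, dim=0)
# partial_sums stacks one (1, 4) slice per step along dim 0 -> shape (10, 4)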

View File

@ -224,6 +224,7 @@ def _from_fun(t):
t.stride(),
dtype=t.dtype,
requires_grad=t.requires_grad,
device=t.device,
)
else:
# clone of a functional tensor produces a functional tensor

View File

@ -6218,12 +6218,12 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(combine_fn: ir.Subgraph, input, dim: int):
def associative_scan(combine_fn: ir.Subgraph, xs, dim: int):
from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph
subgraph_inputs = [
InputDescriptor(dtype=x.get_dtype(), device=x.get_device())
for x in itertools.chain(input, input)
for x in itertools.chain(xs, xs)
]
lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs) # type: ignore[var-annotated]
@ -6233,9 +6233,9 @@ def associative_scan(combine_fn: ir.Subgraph, input, dim: int):
*pytree.tree_leaves(rhs),
)
kwargs = _make_scan_inner(input[0], axis=dim, dtype=None)
kwargs["dtypes"] = tuple(x.get_dtype() for x in input)
kwargs["inner_fns"] = tuple(x.make_loader() for x in input)
kwargs = _make_scan_inner(xs[0], axis=dim, dtype=None)
kwargs["dtypes"] = tuple(x.get_dtype() for x in xs)
kwargs["inner_fns"] = tuple(x.make_loader() for x in xs)
result = ir.Scan.create(
combine_fn=wrapped_combine_fn,
can_fallback_to_aten=False,