Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Implementation of scan (#134102)
This operation is intended as the counterpart to `associative_scan`, but it can also operate with non-associative functions. @ydwu4
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134102
Approved by: https://github.com/ydwu4
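For illustration, a minimal usage sketch based on the signature this PR adds, scan(combine_fn, init, xs, dim=0, reverse=False) returning (final_carry, stacked_outputs). The combine function below is a running discounted sum, which is not associative; its name and shapes are illustrative, not part of the PR.

import torch
from torch._higher_order_ops.scan import scan

def discounted_sum(carry, x):
    # carry and x are slices of shape (2, 1); the update depends on the order of
    # application, so associative_scan could not express it, but scan can.
    new_carry = 0.9 * carry + x
    return new_carry, new_carry  # (next carry, one slice of the stacked output)

init = torch.zeros(2, 1)   # initial carry: one slice along the scan dimension
xs = torch.randn(2, 10)    # scan over dim=1, i.e. 10 steps
last_carry, ys = scan(discounted_sum, init, xs, dim=1)
# last_carry keeps the shape of init, (2, 1); ys stacks the 10 output slices along dim=1, giving (2, 10).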
Commit e889252493 (parent 6546c6186d), committed by PyTorch MergeBot
File diff suppressed because it is too large
@@ -23,6 +23,7 @@ from torch import distributed as dist
from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor
from torch._dynamo.testing import make_test_cls_with_patches, rand_strided
from torch._guards import tracing, TracingContext
from torch._higher_order_ops.scan import scan
from torch._subclasses.fake_tensor import (
DynamicOutputShapeException,
extract_tensor_metadata,
@@ -923,6 +924,21 @@ class FakeTensorTest(TestCase):
self.assertIsInstance(r, FakeTensor)
self.assertEqual(r.size(), [3])

@parametrize("reverse", [False, True])
def test_scan(self, reverse):
def add(x, y):
return x + y, x + y

with torch._subclasses.fake_tensor.FakeTensorMode():
x = torch.randn((3, 5, 7), device="cpu")
init = torch.randn((3, 1, 7), device="cpu")
r = scan(add, init, x, dim=1, reverse=reverse)

self.assertIsInstance(r[0], FakeTensor)
self.assertIsInstance(r[1], FakeTensor)
self.assertEqual(r[0].size(), init.size())
self.assertEqual(r[1].size(), x.size())


instantiate_parametrized_tests(FakeTensorTest)
@@ -3188,6 +3188,7 @@ LEGACY_MOD_INLINELIST = {
"torch._higher_order_ops.cond",
"torch._higher_order_ops.while_loop",
"torch._higher_order_ops.associative_scan",
"torch._higher_order_ops.scan",
"torch.nn.attention.flex_attention",
"torch.ao.quantization.pt2e.export_utils",
"torch.ao.quantization.pt2e.qat_utils",
@@ -3228,6 +3229,7 @@ MOD_INLINELIST = [
"torch._functorch.functional_call",
"torch._functorch.vmap",
"torch._higher_order_ops.associative_scan",
"torch._higher_order_ops.scan",
"torch._higher_order_ops.strict_mode",
"torch._higher_order_ops.while_loop",
"torch._inductor.test_operators",
@@ -602,6 +602,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
return RunWithRNGStateHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "associative_scan":
return AssociativeScanHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "scan":
return ScanHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "call_torchbind":
return CallTorchbindHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "wrap_with_set_grad_enabled":
@@ -1022,16 +1024,16 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):

args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

def arg_extractor(combine_fn, input, dim):
return combine_fn, input, dim
def arg_extractor(combine_fn, xs, dim):
return combine_fn, xs, dim

combine_fn, input, dim = arg_extractor(*args, **kwargs)
combine_fn, xs, dim = arg_extractor(*args, **kwargs)

if input.python_type() != list:
if xs.python_type() != list:
unimplemented(
f"Expected input to be a list of tensors but got {input.python_type()}",
f"Expected xs to be a list of tensors but got {xs.python_type()}",
)
assert isinstance(input, torch._dynamo.variables.lists.BaseListVariable)
assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable)

# Trace the subgraph
# TODO: Fix these pointless new_empty calls appearing in the dynamo output graph.
@@ -1054,7 +1056,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
"requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad),
},
)
for leaf in itertools.chain(input.items, input.items)
for leaf in itertools.chain(xs.items, xs.items)
]
(
(combine_result, combine_treespec),
@@ -1065,7 +1067,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
combine_fn,
sub_args,
sub_kwargs={},
description="scan_combine",
description="associative_scan_combine_fn",
source_target=self.value,
set_subgraph_inputs="flatten_manual",
)
@@ -1080,9 +1082,9 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
f"Expected combine_fn to return a list if tensor but got {combine_result.python_type()}",
|
||||
)

input_proxy = input.as_proxy()
xs_proxy = xs.as_proxy()
combine_result_proxy = combine_result.as_proxy()
for result, inp_proxy in zip(combine_result_proxy, input_proxy):
for result, inp_proxy in zip(combine_result_proxy, xs_proxy):
inp_meta = inp_proxy.node.meta["example_value"]
combine_result_meta = result.node.meta["example_value"]
if combine_result_meta.device != inp_meta.device:
@@ -1097,18 +1099,17 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
)

combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
combine_fn_name = add_subgraph(tx, "scan_combine", combine_gm)
combine_fn_name = add_subgraph(tx, "associative_scan_combine_fn", combine_gm)

p_args = (
make_attr(tx, combine_fn_name),
input_proxy,
xs_proxy,
dim.as_proxy(),
)

with tx.fake_mode:
out_meta = tuple(
inp_proxy.node.meta["example_value"].clone()
for inp_proxy in input_proxy
inp_proxy.node.meta["example_value"].clone() for inp_proxy in xs_proxy
)
return wrap_fx_proxy(
tx=tx,
@@ -1119,6 +1120,196 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
)


class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
@raise_hard_error_if_graph_break(
reason="scan must be captured completely with torch.compile."
)
def call_function(
self,
tx: "InstructionTranslator",
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
) -> VariableTracker:
from torch._higher_order_ops.scan import make_expanded_output_shape

from .builder import SourcelessBuilder, wrap_fx_proxy

args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

def arg_extractor(combine_fn, init, xs, dim, reverse):
return combine_fn, init, xs, dim, reverse

combine_fn, init, xs, dim, reverse = arg_extractor(*args, **kwargs)

if xs.python_type() != list:
unimplemented(
f"Expected xs to be a list of tensors but got {xs.python_type()}",
)
assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable)
if init.python_type() != list:
unimplemented(
f"Expected init to be a list of tensors but got {init.python_type()}",
)
assert isinstance(init, torch._dynamo.variables.lists.BaseListVariable)

dim_fake = (
dim.as_proxy()
if type(dim.as_proxy()) == int
else get_fake_value(dim.as_proxy().node, tx)
)
scan_length = get_fake_value(xs.items[0].as_proxy().node, tx).size()[dim_fake]
if scan_length == 0:
unimplemented(
"scan() operator doesn't support zero-sized tensors during tracing."
)

init_len = len(init.items)
if init_len == 0:
unimplemented("scan() operator requires init leaves.")

# Trace the subgraph
# TODO: Fix these pointless new_empty calls appearing in the dynamo output graph.
# TODO: Unify handling of sub_args across control flow ops, such as cond, while_loop, etc.
sub_args_init = [
ini.call_method(
tx,
"new_empty",
args=(
SourcelessBuilder.create(
tx,
ini.size
if ini.size is not None
else tuple(
BuiltinVariable(getattr)
.call_function(
tx, [ini, ConstantVariable.create("shape")], {}
)
.items
),
),
),
kwargs={
"dtype": SourcelessBuilder.create(tx, ini.dtype),
"device": SourcelessBuilder.create(tx, ini.device),
"requires_grad": SourcelessBuilder.create(tx, ini.requires_grad),
},
)
for ini in init.items
]
sub_args_inp_shapes = make_expanded_output_shape(
dim_fake,
1,
[
tuple(
BuiltinVariable(getattr)
.call_function(tx, [inp, ConstantVariable.create("shape")], {})
.items
)
for inp in xs.items
],
True,
)
sub_args_inp = [
inp.call_method(
tx,
"new_empty",
args=(SourcelessBuilder.create(tx, inp_sh),),
kwargs={
"dtype": SourcelessBuilder.create(tx, inp.dtype),
"device": SourcelessBuilder.create(tx, inp.device),
"requires_grad": SourcelessBuilder.create(tx, inp.requires_grad),
},
)
for inp, inp_sh in zip(xs.items, sub_args_inp_shapes)
]
sub_args = sub_args_init + sub_args_inp
(
(combine_result, combine_treespec),
combine_graph,
combine_lifted_freevars,
) = speculate_subgraph(
tx,
combine_fn,
sub_args,
sub_kwargs={},
description="scan_combine_fn",
source_target=self.value,
set_subgraph_inputs="flatten_manual",
)

if combine_lifted_freevars:
unimplemented(
f"Combine fn had unexpected freevars: {combine_lifted_freevars}"
)

if any(cr.python_type() != list for cr in combine_result.items):
unimplemented(
f"Expected combine_fn to return a list if tensor but got {combine_result.python_type()}",
|
||||
)

xs_proxy = xs.as_proxy()
init_proxy = init.as_proxy()
combine_carry_proxy = combine_result.items[0].as_proxy()

# Checks for carry and init
for ini_proxy, carry in zip(init_proxy, combine_carry_proxy):
ini_meta = ini_proxy.node.meta["example_value"]
carry_meta = carry.node.meta["example_value"]
if (
carry_meta.device != ini_meta.device
or carry_meta.dtype != ini_meta.dtype
or carry_meta.shape != ini_meta.shape
):
unimplemented(
f"Expected metadata of the combine_fn result {carry_meta} to be the same as "
+ f"the metadata of init with {ini_meta}"
)

combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
combine_fn_name = add_subgraph(tx, "scan_combine_fn", combine_gm)

p_args = (
make_attr(tx, combine_fn_name),
init_proxy,
xs_proxy,
dim.as_proxy(),
reverse.as_proxy(),
)

with tx.fake_mode:
# For the fake mode, we need to duplicate the init tensor along the dim
# to have the same size as the xs arguments
# We also do a clone with contiguous_format. This is to be consistent with
# eager semantic of map, which stacks the outputs. The result is contiguous
# as a result of the stack operation.
fake_out_shapes = make_expanded_output_shape(
dim_fake,
scan_length,
[
get_fake_value(o.as_proxy().node, tx).size()
for o in combine_result.items[1].items
],
)
out_meta = (
[init_p.node.meta["example_value"].clone() for init_p in init_proxy],
list(  # noqa: C400
t.as_proxy()
.node.meta["example_value"]
.expand(*sh)
.clone(memory_format=torch.contiguous_format)
for t, sh in zip(combine_result.items[1].items, fake_out_shapes)
),
)

return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function", torch.ops.higher_order.scan, p_args, {}
),
example_value=out_meta,
)


def non_single_tensor_return_unsupported(api, ret):
from . import TensorVariable

@@ -73,8 +73,8 @@ class AssociativeScanOp(HigherOrderOperator):
def __init__(self):
super().__init__("associative_scan")

def __call__(self, combine_fn, input, dim):
return super().__call__(combine_fn, input, dim)
def __call__(self, combine_fn, xs, dim):
return super().__call__(combine_fn, xs, dim)


associative_scan_op = AssociativeScanOp()
@@ -82,13 +82,13 @@ associative_scan_op = AssociativeScanOp()

def associative_scan(
combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree],
input: pytree.PyTree,
xs: pytree.PyTree,
dim: int,
reverse: bool = False,
combine_mode: str = "pointwise",
) -> torch.Tensor:
r"""
Performs an inclusive scan with an associative pointwise combine function.
Performs an inclusive scan with an associative combine function.

.. warning::
`torch.associative_scan` is a prototype feature in PyTorch. It currently
@@ -102,14 +102,15 @@ def associative_scan(
Args:
combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``,
or if input is a pytree ``(pytree, pytree) -> pytree``.
This function must be pure, pointwise, and satisfy the associative property.
input (torch.Tensor): The input tensor, or nested pytree of tensors.
This function must be pure, i.e., no lifted arguments are supported at the moment,
satisfy the associative property and have no side-effects.
xs (torch.Tensor): The input tensor, or nested pytree of tensors.
All inputs are expected to have the same shape.
dim (int): the dimension to scan over
reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension.
combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``.
reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``.
combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``, default ``pointwise``.
If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations
and ``input`` must be CUDA tensors.
and ``xs`` must be CUDA tensors.
In all other cases ``combine_mode=generic`` should be used.
Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``.

@@ -122,27 +123,32 @@ def associative_scan(
cumsum = associative_scan(add, x, dim)

"""
assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}"
assert isinstance(dim, int), "dim must be an int, but got {type(dim)}"
assert combine_mode in ["pointwise", "generic"]
if not callable(combine_fn):
raise RuntimeError(f"Combine_fn must be a callable, but got {combine_fn}")
if not isinstance(dim, int):
raise RuntimeError("Dim must be an int, but got " + str(type(dim)))
if combine_mode not in ["pointwise", "generic"]:
raise RuntimeError(
f"Combine_mode must be either 'pointwise' or 'generic', but got {combine_mode}"
)

if not torch._dynamo.is_compiling():
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(associative_scan, fullgraph=True)(
combine_fn, input, dim, reverse=reverse, combine_mode=combine_mode
combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode
)

leaves, spec = pytree.tree_flatten(input)
leaves, spec = pytree.tree_flatten(xs)

if combine_mode == "pointwise" and not all(l.device.type == "cuda" for l in leaves):
raise ValueError(
"For combine_mode='pointwise', all input tensors need to be on CUDA"
)

assert len(leaves) >= 1, "expected at least 1 input leaf"
assert all(
isinstance(x, torch.Tensor) for x in leaves
), "input leaves must be a Tensor"
if len(leaves) == 0:
raise RuntimeError("Expected at least 1 xs leaf")
if any(not isinstance(x, torch.Tensor) for x in leaves):
raise RuntimeError("xs leaves must be a Tensor")

if reverse:
leaves = [torch.flip(elem, [dim]) for elem in leaves]
@@ -152,20 +158,21 @@ def associative_scan(
dim = utils.canonicalize_dim(ndim, dim)

for x in leaves[1:]:
assert x.shape == shape, "All input tensors must have the same shape"
assert x.shape == shape, "All xs tensors must have the same shape"

out = combine_fn(
pytree.tree_unflatten(leaves, spec),
pytree.tree_unflatten(leaves, spec),
)
out_leaves, tree_out = pytree.tree_flatten(out)
assert len(leaves) == len(
out_leaves
), "The pytree of the output of the operator needs to match the input pytree"
for x in out_leaves:
assert (
x.shape == shape
), "The pytree of the output of the operator needs to match the input pytree"
if len(leaves) != len(out_leaves):
raise RuntimeError(
"The number of leaves of the pytree of the output of the operator needs to match the length of the pytree of the input"
)
if any(x.shape != shape for x in out_leaves):
raise RuntimeError(
"The pytree of the output of the operator needs to match the xs pytree"
)

combine_fn = functools.partial(
wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves)
@@ -194,7 +201,7 @@ def generic_associative_scan(operator, elems_flat, dim=0):
or if input is a pytree ``(pytree, pytree) -> pytree``.
This function must be pure, pointwise, and satisfy the associative property.
elems_flat (torch.Tensor): A list of torch.Tensors converted from the pytree of
``input`` provided to ``associative_scan``.
``xs`` provided to ``associative_scan``.
All inputs are expected to have the same shape.
dim (int): the dimension to scan over

@@ -279,19 +286,19 @@ def generic_associative_scan(operator, elems_flat, dim=0):


def trace_associative_scan(
proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int
proxy_mode, func_overload, combine_fn: Callable, xs: List[torch.Tensor], dim: int
):
with disable_proxy_modes_tracing():
sample_inputs = [
sample_xs = [
torch.empty_like(
x,
dtype=x.dtype,
device=x.device,
requires_grad=x.requires_grad,
)
for x in itertools.chain(input, input)
for x in itertools.chain(xs, xs)
]
combine_graph = reenter_make_fx(combine_fn)(*sample_inputs)
combine_graph = reenter_make_fx(combine_fn)(*sample_xs)

outputs = None
for node in combine_graph.graph.nodes:
@@ -307,10 +314,10 @@ def trace_associative_scan(

assert outputs is not None
assert len(outputs) == len(
input
), f"expected combine_fn to return {len(input)} results but got {len(outputs)}"
xs
), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}"

for i, o in zip(input, outputs):
for i, o in zip(xs, outputs):
o_meta = o.meta["tensor_meta"]
assert o_meta.dtype == i.dtype, (
f"combine_fn output type mismatch, expected {i.dtype} "
@@ -321,20 +328,20 @@ def trace_associative_scan(

proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph)

args = (combine_graph, input, dim)
args = (combine_graph, xs, dim)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="associative_scan"
)

with disable_proxy_modes_tracing():
out = [aten.clone(x) for x in input]
out = [aten.clone(x) for x in xs]

return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def associative_scan_op_dense(combine_fn, input, dim):
def associative_scan_op_dense(combine_fn, xs, dim):
raise NotImplementedError("associative_scan is not implemented for eager")


@@ -344,22 +351,22 @@ associative_scan_op.py_impl(DispatchKey.Autograd)(


@associative_scan_op.py_impl(ProxyTorchDispatchMode)
def associative_scan_proxy_mode(mode, combine_fn, input, dim):
return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim)
def associative_scan_proxy_mode(mode, combine_fn, xs, dim):
return trace_associative_scan(mode, associative_scan_op, combine_fn, xs, dim)


@associative_scan_op.py_impl(FakeTensorMode)
def assoiciative_scan_fake_tensor_mode(mode, combine_fn, input, dim):
def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs, dim):
with mode:
return [x.clone() for x in input]
return [x.clone() for x in xs]


@associative_scan_op.py_functionalize_impl
def associative_scan_functionalize(ctx, combine_fn, input, dim):
unwrapped_input = ctx.unwrap_tensors(input)
def associative_scan_functionalize(ctx, combine_fn, xs, dim):
unwrapped_xs = ctx.unwrap_tensors(xs)
with ctx.redispatch_to_next() as m:
functional_combine_fn = ctx.functionalize(
_maybe_run_with_interpreter(combine_fn)
)
ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim)
ret = associative_scan_op(functional_combine_fn, unwrapped_xs, dim)
return ctx.wrap_tensors(ret)
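For orientation before the new file below, a plain-Python sketch of the intended semantics (illustrative only, not the PR's implementation): associative_scan combines the elements of xs into an inclusive prefix, while the new scan additionally threads an explicit carry that starts at init and also returns the per-step outputs.

def reference_associative_scan(combine_fn, xs):
    # xs: a plain Python list; returns the inclusive prefix combination.
    out = [xs[0]]
    for x in xs[1:]:
        out.append(combine_fn(out[-1], x))
    return out

def reference_scan(combine_fn, init, xs):
    # combine_fn(carry, x) -> (next_carry, y); returns (final_carry, [y_0, ..., y_{n-1}]).
    carry, out = init, []
    for x in xs:
        carry, y = combine_fn(carry, x)
        out.append(y)
    return carry, out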
torch/_higher_order_ops/scan.py (new file, 438 lines)
@@ -0,0 +1,438 @@
# mypy: allow-untyped-defs
import functools
import itertools
from typing import Callable, List, Tuple

import torch
import torch._prims_common as utils
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_set_compilation_env,
autograd_not_implemented,
reenter_make_fx,
unique_graph_id,
UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.utils._python_dispatch import _get_current_dispatch_mode


aten = torch._ops.ops.aten


def wrap_combine_fn_flat(
*args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves
):
assert len(args) == (num_init_leaves + num_inp_leaves)
carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init)
xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs)
carry, combined = combine_fn(carry, xs)
carry_flat = pytree.tree_leaves(carry)
combined_flat = pytree.tree_leaves(combined)
assert num_init_leaves == len(carry_flat)
return (carry_flat, combined_flat)


def scan(
combine_fn: Callable[
[pytree.PyTree, pytree.PyTree], Tuple[pytree.PyTree, pytree.PyTree]
],
init: pytree.PyTree,
xs: pytree.PyTree,
/,
*,
dim: int = 0,
reverse: bool = False,
) -> Tuple[pytree.PyTree, pytree.PyTree]:
r"""
Performs an inclusive scan with a combine function.

.. warning::
`torch.scan` is a prototype feature in PyTorch. It currently
does not support autograd and you may run into miscompiles.
Read more about feature classification at:
https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Args:
combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> (Tensor, Tensor)``,
or if xs is a pytree ``(pytree, pytree) -> (pytree, pytree)``.
The first input to ``combine_fn`` is the previous or initial scan carry
and the second input element to ``combine_fn`` is a slice of the input along dim.
The first output element of ``combine_fn`` is the next scan carry
and the second output of ``combine_fn`` represents a slice of the output.
This function must be pure, i.e., no lifted arguments are supported at the moment
and may not have any side effects.
init (torch.Tensor or pytree with tensor leaves): The initial scan carry, a tensor, or nested pytree of tensors.
The ``init`` is expected to have the same pytree structure as the first output element (i.e. carry)
of ``combine_fn``.
xs (torch.Tensor or pytree with tensor leaves): The input tensor, or nested pytree of tensors.

Kwargs:
dim (int): the dimension to scan over, default 0.
reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``.

Returns:
final_carry (torch.Tensor or pytree with tensor leaves),
the final carry of the scan operation with same pytree structure as init.
out (torch.Tensor or pytree with tensor leaves),
each tensor leaf is a stacked output along dim, where each slice is the output of a scan iteration.

Example::

def add(x: torch.Tensor, y: torch.Tensor):
next_carry = y = x + y
return next_carry, y

i0 = torch.zeros(1)
xs = torch.arange(1, 5)
# returns torch.tensor([10]), torch.tensor([1., 3., 6., 10.])
last_carry, cumsum = scan(add, init=i0, xs=xs)


"""
if not callable(combine_fn):
raise RuntimeError("Combine_fn must be a callable, but got {combine_fn}")
|
||||
if not isinstance(dim, int):
raise RuntimeError("Dim must be an int, but got " + str(type(dim)))
if not isinstance(reverse, bool):
raise RuntimeError("Reverse must be a bool, but got " + str(type(reverse)))

# TODO: Support closures/nn_modules in order to be able to represent RNNs with scan
# TODO: Support _inductor lowering
# TODO: Support Autograd
# TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc.

# Dynamo is expecting a callable with "__code__" attribute.
# We cannot directly pass cond_op to it. So we wrap it in a dummy function.
def _scan_op_wrapper(*args, **kwargs):
return scan(*args, **kwargs)

if not torch._dynamo.is_compiling():
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
return torch.compile(_scan_op_wrapper, backend="eager", fullgraph=True)(
combine_fn, init, xs, dim=dim, reverse=reverse
)

leaves_init, spec_init = pytree.tree_flatten(init)
leaves_xs, spec_xs = pytree.tree_flatten(xs)

if len(leaves_init) == 0:
raise RuntimeError("Init tensors must be provided")
if any(not isinstance(x, torch.Tensor) for x in leaves_init):
raise RuntimeError("All init leaves must be a Tensor")
if any(not isinstance(x, torch.Tensor) for x in leaves_xs):
raise RuntimeError("All xs leaves must be a Tensor")
if any(x.shape[dim] == 0 for x in leaves_xs):
raise RuntimeError("All xs leaves must have a scan dimension > 0")

if len(leaves_xs) > 0:
shape = leaves_xs[0].shape
ndim = len(shape)
dim = utils.canonicalize_dim(ndim, dim)

out = combine_fn(
pytree.tree_unflatten(leaves_init, spec_init),
pytree.tree_unflatten(
[aten.slice(elem, dim, 0, 1, 1) for elem in leaves_xs], spec_xs
),
)

# The first output needs to have the same pytree as init
carry_leaves = pytree.tree_leaves(out[0])
if len(carry_leaves) != len(leaves_init):
raise RuntimeError(
"The number of leaves of the pytree of the new carry produced by the operator\
needs to match the length of the pytree of the init"
)
if any(
in_l.shape != out_l.shape for in_l, out_l in zip(leaves_init, carry_leaves)
):
raise RuntimeError(
"The pytree of the new carry produced by the operator needs to match the pytree of the init"
)

# There are no pytree restrictions on the second output of the operator
out_leaves, tree_out = pytree.tree_flatten(out[1])

combine_fn = functools.partial(
wrap_combine_fn_flat,
combine_fn=combine_fn,
spec_init=spec_init,
spec_xs=spec_xs,
num_init_leaves=len(leaves_init),
num_inp_leaves=len(leaves_xs),
)

result_carry, result_flat = scan_op(
combine_fn, leaves_init, leaves_xs, dim, reverse
)

return pytree.tree_unflatten(result_carry, spec_init), pytree.tree_unflatten(
result_flat, tree_out
)

else:
return pytree.tree_unflatten(leaves_init, spec_init), xs


class ScanOp(HigherOrderOperator):
def __init__(self):
super().__init__("scan")

def __call__(self, combine_fn, init, xs, dim, reverse):
return super().__call__(combine_fn, init, xs, dim, reverse)


scan_op = ScanOp()


def generic_scan(operator, init, xs, dim=0, reverse=False):
def _scan(init, xs):
"""Perform scan on `elems` using `elems_init."""
|
||||
carry = init
if len(xs) == 0:
return carry, []

num_elems = xs[0].shape[dim]
if reverse:
ind = num_elems - 1
else:
ind = 0

# Compute dummy shapes for the pre-allocation
dummy_carry, dummy_out = operator(
*carry, *[aten.slice(elem, dim, 0, 1, 1) for elem in xs]
)
output_scanned_dim = dummy_out[0].shape[dim]

# Pre-allocate
# outs -> Output matrix
# idxs -> Index matrix for scatter_
outs, outs_idxs = zip(
*[
[
torch.zeros(
list(e.size())[:dim]
+ [list(e.size())[dim] * num_elems]
+ list(e.size())[dim + 1 :],
dtype=e.dtype,
device=e.device,
),
torch.cat(
[
id * t
for id, t in zip(
range(output_scanned_dim),
torch.tensor_split(
torch.ones_like(e, dtype=torch.int64),
output_scanned_dim,
dim=dim,
),
)
],
dim,
),
]
for i, e in enumerate(dummy_out)
]
)

def store_in_mat(mat, out, d, index, index_modifier):
# Store the intermediate out in the outs matrix
for o, x, idx in zip(mat, out, index):
o.scatter_(d, idx + index_modifier, x)

def cond(i, n, r):
if (r and i < 0) or (not r and i > (n - 1)):
return False
else:
return True

def op(i):
if reverse:
return i - 1
else:
return i + 1

while cond(ind, num_elems, reverse):
carry, out = operator(
*carry,
*[aten.slice(elem, dim, ind, ind + 1, 1) for elem in xs],
)

# Store the inits in the outs matrix.
store_in_mat(outs, out, dim, outs_idxs, ind * output_scanned_dim)

ind = op(ind)

return (carry, list(outs))

scans = _scan(init, xs)
return scans


def make_expanded_output_shape(dim, scan_length, shapes, use_sh=False):
expanded_shapes = [
tuple(
(s if use_sh else -1) if i != dim else scan_length for i, s in enumerate(sh)
)
for sh in shapes
]
return expanded_shapes


def trace_scan(
proxy_mode,
func_overload,
combine_fn: Callable,
init: List[torch.Tensor],
xs: List[torch.Tensor],
dim: int,
reverse: bool,
):
with disable_proxy_modes_tracing():
sample_inits = [
torch.empty_like(
x_init,
dtype=x_init.dtype,
device=x_init.device,
requires_grad=x_init.requires_grad,
)
for x_init in init
]
sample_xs = [
torch.empty_like(
aten.slice(x, dim, 0, 1, 1),
dtype=x.dtype,
device=x.device,
requires_grad=x.requires_grad,
)
for x in xs
]
combine_graph = reenter_make_fx(combine_fn)(*sample_inits, *sample_xs)

outputs = None
for node in combine_graph.graph.nodes:
if node.op == "output":
assert outputs is None
assert len(node.args) == 1
outputs = node.args[0]

assert outputs is not None
if len(outputs) != 2:
raise RuntimeError(
f"Expected to return 2 outputs: carry, out_matrix, but got:"
f"\n {len(outputs)} elements"
)

for ini, carry in zip(init, outputs[0]):
ini_meta = ini
carry_meta = carry.meta["tensor_meta"]
carry_val = carry.meta["val"]
if (
carry_val.device != ini_meta.device
or carry_meta.dtype != ini_meta.dtype
or carry_meta.shape != ini_meta.shape
):
raise RuntimeError(
f"Expected metadata of the combine_fn result {carry_meta} to be the same as "
+ f"the metadata of init with {ini_meta}"
)

_, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")

proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph)

args = (combine_graph, init, xs, dim, reverse)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", func_overload, proxy_args, {}, name="scan"
)

with disable_proxy_modes_tracing():
scan_length = xs[0].shape[dim]
fake_out_shapes = make_expanded_output_shape(
dim, scan_length, [o.meta["val"].size() for o in outputs[1]]
)

def expand_tensor(t, sh):
if isinstance(t, torch.Tensor):
return t.expand(*sh)
return t

expanded_outs = [
pytree.tree_map(expand_tensor, t.meta["val"], sh)
for t, sh in zip(outputs[1], fake_out_shapes)
]
out = (init, expanded_outs)

return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def scan_op_dense(combine_fn, init, xs, dim, reverse):
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return generic_scan(combine_fn, init, xs, dim, reverse)


scan_op.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(scan_op, deferred_error=True)
)


@scan_op.py_impl(ProxyTorchDispatchMode)
def scan_proxy_mode(mode, combine_fn, init, xs, dim, reverse):
return trace_scan(mode, scan_op, combine_fn, init, xs, dim, reverse)


@scan_op.py_impl(FakeTensorMode)
def scan_fake_tensor_mode(mode, combine_fn, init, xs, dim, reverse):
with mode:
dim_len = xs[0].shape[dim]
carry, outputs = combine_fn(
*init, *[aten.slice(inp, dim, 0, 1, 1) for inp in xs]
)
fake_out_shapes = [
tuple(-1 if i != dim else dim_len for i, sh in enumerate(o.size()))
for o in outputs
]
out = (
carry,
tuple(t.expand(*sh).clone() for t, sh in zip(outputs, fake_out_shapes)),
)
return out


@scan_op.py_functionalize_impl
def scan_functionalize(ctx, combine_fn, init, xs, dim, reverse):
unwrapped_xs = ctx.unwrap_tensors(xs)
unwrapped_init = ctx.unwrap_tensors(init)
with ctx.redispatch_to_next() as m:
functional_combine_fn = ctx.functionalize(combine_fn)
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
sample_xs = list(itertools.chain(unwrapped_init, unwrapped_init))
if _has_potential_branch_input_mutation(
functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch
):
raise UnsupportedAliasMutationException(
"Combine_fn might be modifying the input!"
)
if _has_potential_branch_input_alias(
functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch
):
raise UnsupportedAliasMutationException(
"Combine_fn might be aliasing the input!"
)
ret = scan_op(functional_combine_fn, unwrapped_init, unwrapped_xs, dim, reverse)
return ctx.wrap_tensors(ret)
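To make the documented semantics of the new file concrete, a hedged sketch (assuming the eager path above behaves as documented; the reference loop and names here are illustrative, not part of the file): the stacked output of scan should match an explicit Python loop over size-1 slices along dim.

import torch
from torch._higher_order_ops.scan import scan

def add(carry, x):
    s = carry + x
    return s, s

init = torch.zeros(3, 1, 7)
xs = torch.randn(3, 5, 7)
final_carry, out = scan(add, init, xs, dim=1)

# Reference loop over dim=1 with the same combine function.
carry, chunks = init, []
for i in range(xs.shape[1]):
    carry, y = add(carry, xs[:, i : i + 1, :])
    chunks.append(y)

torch.testing.assert_close(out, torch.cat(chunks, dim=1))
torch.testing.assert_close(final_carry, carry)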
@@ -224,6 +224,7 @@ def _from_fun(t):
t.stride(),
dtype=t.dtype,
requires_grad=t.requires_grad,
device=t.device,
)
else:
# clone of a functional tensor produces a functional tensor

@@ -6218,12 +6218,12 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):


@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(combine_fn: ir.Subgraph, input, dim: int):
def associative_scan(combine_fn: ir.Subgraph, xs, dim: int):
from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph

subgraph_inputs = [
InputDescriptor(dtype=x.get_dtype(), device=x.get_device())
for x in itertools.chain(input, input)
for x in itertools.chain(xs, xs)
]
lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs)  # type: ignore[var-annotated]

@@ -6233,9 +6233,9 @@ def associative_scan(combine_fn: ir.Subgraph, input, dim: int):
*pytree.tree_leaves(rhs),
)

kwargs = _make_scan_inner(input[0], axis=dim, dtype=None)
kwargs["dtypes"] = tuple(x.get_dtype() for x in input)
kwargs["inner_fns"] = tuple(x.make_loader() for x in input)
kwargs = _make_scan_inner(xs[0], axis=dim, dtype=None)
kwargs["dtypes"] = tuple(x.get_dtype() for x in xs)
kwargs["inner_fns"] = tuple(x.make_loader() for x in xs)
result = ir.Scan.create(
combine_fn=wrapped_combine_fn,
can_fallback_to_aten=False,