Compare commits


2 Commits

SHA1        Message                          Date
3c7ffb1c63  scan attempt                     2025-11-04 07:06:56 -08:00
ed5bd1431f  [label_to_label] minor updates   2025-10-24 11:44:17 -07:00

Full commit message for ed5bd1431f:

    vllm-compile implies "module: vllm" and "oncall: pt2".

    The Flex -> HigherOrderOperators rule is too noisy given the volume of
    issues, plus we have a different set of folks looking at each, so I'm
    going to make that labeling no longer automatic. We can still manually
    label flex issues as higher order operator issues.

    ghstack-source-id: 6715a72fb19030a0e28db2cda4fcb40fc04e3716
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/166172
29 changed files with 178 additions and 10 deletions


@@ -15,6 +15,11 @@
   - "module: reinplacing"
   then:
   - "module: pt2-dispatcher"
+- any:
+  - "vllm-compile"
+  then:
+  - "module: vllm"
+  - "oncall: pt2"
 - any:
   - "module: vmap"
   then:
@@ -27,10 +32,6 @@
   - "module: pt2 optimizer"
   then:
   - "module: dynamo"
-- any:
-  - "module: flex attention"
-  then:
-  - "module: higher order operators"
 - any:
   - "module: aotinductor"
   then:


@@ -3368,6 +3368,15 @@ class TestVmapOperators(Namespace.TestVmapBase):
     @parametrize("out_dim", [0, 1, 2])
     @parametrize("randomness", ["error", "same"])
     def test_vmap_chunksize(self, in_dim, out_dim, randomness):
+        self._test_vmap_chunksize(in_dim, out_dim, randomness, chunk_with_scan=False)
+
+    @parametrize("in_dim", [0, 1, 2])
+    @parametrize("out_dim", [0])
+    @parametrize("randomness", ["error", "same"])
+    def test_vmap_chunk_with_scan_basic(self, in_dim, out_dim, randomness):
+        self._test_vmap_chunksize(in_dim, out_dim, randomness, chunk_with_scan=True)
+
+    def _test_vmap_chunksize(self, in_dim, out_dim, randomness, chunk_with_scan):
         x = torch.randn(4, 5, 6)
         y = torch.randn_like(x)
@@ -3455,7 +3464,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
         expected_vmap = vmap(fn, **kwargs)(*args)
         for chunk_size in (1, 2, 3, 4, 7, 10, 16, 100):
             torch.set_rng_state(rs)
-            output = vmap(fn, chunk_size=chunk_size, **kwargs)(*args)
+            output = vmap(fn, chunk_size=chunk_size, chunk_with_scan=chunk_with_scan, **kwargs)(*args)
             self.assertEqual(output, expected_vmap)

     @parametrize("in_dim", [0, 1])
@@ -6045,8 +6054,7 @@ class TestRandomness(TestCase):
         self._assert_all_slices_unique(output)

     @parametrize("in_dim", [0, 1, 2])
-    @parametrize("out_dim", [0, 1, 2])
-    def test_vmap_chunksize(self, in_dim, out_dim):
+    def test_vmap_chunksize(self, in_dim):
         randomness = "different"
         x = torch.randn(4, 5, 6)
@@ -6059,12 +6067,59 @@ class TestRandomness(TestCase):
             output = vmap(
                 f,
                 in_dims=in_dim,
-                out_dims=out_dim,
                 randomness=randomness,
                 chunk_size=chunk_size,
             )(x)
             self._assert_all_slices_unique(output)

+    @parametrize("in_dim", [0, 1, 2])
+    def test_vmap_chunk_with_scan(self, in_dim):
+        randomness = "different"
+        x = torch.randn(4, 8, 16)
+
+        def f(x):
+            y = x.sin() + torch.rand_like(x)
+            return y
+
+        for chunk_size in [1, 2, 4]:
+            output = torch.vmap(
+                f,
+                in_dims=in_dim,
+                randomness=randomness,
+                chunk_size=chunk_size,
+                chunk_with_scan=True,
+            )(x)
+            self._assert_all_slices_unique(output)
+
+    @parametrize("in_dim1", [0, 1])
+    @parametrize("in_dim2", [0, 1])
+    def test_vmap_chunk_with_scan_nested(self, in_dim1, in_dim2):
+        randomness = "different"
+        x = torch.randn(4, 8, 16)
+
+        def f(x):
+            y = x.sin() + torch.rand_like(x)
+            return y
+
+        for chunk_size1 in [1, 2, 4]:
+            for chunk_size2 in [1, 2, 4]:
+                output = torch.vmap(
+                    lambda x: torch.vmap(
+                        f,
+                        in_dims=in_dim2,
+                        randomness=randomness,
+                        chunk_size=chunk_size2,
+                        chunk_with_scan=True,
+                    )(x),
+                    in_dims=in_dim1,
+                    randomness=randomness,
+                    chunk_size=chunk_size1,
+                    chunk_with_scan=True,
+                )(x)
+                self._assert_all_slices_unique(output)
+
     def test_jacfwd_with_random(self):
         # checks on behavior are above, this just checks that jacfwd respects
         # the randomness param

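For orientation on what these randomness tests assert: with randomness="different", each slice of the vmapped batch must see distinct random draws, which is what _assert_all_slices_unique checks. A minimal standalone sketch of that property (illustrative only; it does not use the test suite's helpers):

import torch

# Under randomness="different", every batch slice of a vmapped random op
# should receive its own random draw.
f = lambda t: t + torch.rand_like(t)
x = torch.zeros(4, 3)
out = torch.vmap(f, randomness="different")(x)
for i in range(4):
    for j in range(i + 1, 4):
        assert not torch.equal(out[i], out[j])  # no two slices identical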

@@ -770,6 +770,7 @@ def validate_args_and_maybe_create_graph_inputs(
         # If `a` cannot be put into a graph
         else:
             # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic").
+            breakpoint()
             unimplemented(
                 f"{description} with body that accepts non-Tensors as input. "
                 f"Got: {a.python_type()}"


@@ -35,6 +35,7 @@ def vmap(
     randomness: str = "error",
     *,
     chunk_size=None,
+    chunk_with_scan=False,
 ) -> Callable:
     """
     vmap is the vectorizing map; ``vmap(func)`` returns a new function that
@@ -204,9 +205,14 @@ def vmap(
             f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})"
         )
+    if chunk_size is None and chunk_with_scan:
+        raise ValueError(
+            "vmap: chunk_with_scan can only be used when chunk_size is specified."
+        )
+
     def wrapped(*args, **kwargs):
         return vmap_impl(
-            func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
+            func, in_dims, out_dims, randomness, chunk_size, chunk_with_scan, *args, **kwargs
         )

     if not is_compiling():

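A usage sketch of the API change above (hedged: chunk_with_scan exists only on this branch, while chunk_size alone is in released PyTorch):

import torch

# Chunk a batch of 8, expressing the chunked loop as a scan.
# chunk_size must be set, otherwise the new ValueError above is raised,
# and the batch size must be divisible by chunk_size (see the
# NotImplementedError in the implementation below).
f = lambda row: row.sin().sum()
x = torch.randn(8, 5)
out = torch.vmap(f, chunk_size=2, chunk_with_scan=True)(x)
assert out.shape == (8,)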

@@ -63,6 +63,7 @@ def vmap(
     randomness: str = "error",
     *,
     chunk_size=None,
+    chunk_with_scan=False,
 ) -> Callable:
     warn_deprecated("vmap", "torch.vmap")
-    return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size)
+    return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size, chunk_with_scan=chunk_with_scan)

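The deprecated functorch.vmap wrapper simply forwards to torch.vmap, so it can be exercised the same way. A hedged sketch, assuming this branch with the flag forwarded as shown above:

import torch
import functorch  # deprecated entry point

x = torch.randn(4, 2)
# identical semantics to torch.vmap with the same arguments
out = functorch.vmap(lambda t: t.sum(), chunk_size=2, chunk_with_scan=True)(x)
assert out.shape == (4,)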

@@ -26,6 +26,7 @@ from torch._functorch.predispatch import (
 from torch.utils._pytree import (
     _broadcast_to_and_flatten,
     tree_flatten,
+    tree_map,
     tree_map_,
     tree_unflatten,
     TreeSpec,
@@ -258,7 +259,97 @@ def _get_name(func: Callable):
     return repr(func)


-def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
+def _flat_vmap_chunk_with_scan(
+    func,
+    batch_size,
+    flat_in_dims,
+    flat_args,
+    args_spec,
+    chunk_size,
+    randomness,
+    out_dims,
+    **kwargs,
+):
+    if batch_size % chunk_size != 0:
+        # TODO: support padding
+        raise NotImplementedError(
+            f"vmap(chunk_with_scan=True): batch_size ({batch_size}) must be divisible by chunk_size ({chunk_size})"
+        )
+    if out_dims != 0:
+        raise NotImplementedError(
+            f"vmap(chunk_with_scan=True): we only support out_dims=0 for now, got {out_dims}"
+        )
+    num_chunks = batch_size // chunk_size
+
+    # Strategy: overall, we're going to do a scan(vmap(f), ...).
+    #
+    # We create a scan dimension by:
+    # - moving all the in_dims to the front,
+    # - then splitting B into (chunk_size, num_chunks, *),
+    # - then scanning over the chunk_size dimension.
+    #
+    # We additionally need to split the tensors into ones that have a
+    # bdim and ones that don't. We only scan over the tensors that do
+    # have a bdim (the ones being vmapped over); the other tensors are
+    # implicitly captured by the function being scanned.
+    def reshape_for_scan(x, in_dim):
+        x = x.movedim(in_dim, 0)
+        x = x.reshape(chunk_size, num_chunks, *x.shape[1:])
+        return x
+
+    # Split into scanned vs unscanned args
+    flat_scanned_args = []
+    flat_unscanned_args = []
+    for in_dim, arg in zip(flat_in_dims, flat_args):
+        if in_dim is None:
+            flat_unscanned_args.append(arg)
+        else:
+            flat_scanned_args.append(reshape_for_scan(arg, in_dim))
+
+    def func_to_scan(dummy, flat_scanned_args):
+        # Interleave scanned and unscanned args back into their original
+        # positions before handing them to _flat_vmap.
+        flat_all_args = []
+        flat_scanned_args_it = iter(flat_scanned_args)
+        flat_unscanned_args_it = iter(flat_unscanned_args)
+        for in_dim in flat_in_dims:
+            if in_dim is None:
+                new_arg = next(flat_unscanned_args_it)
+            else:
+                new_arg = next(flat_scanned_args_it)
+            flat_all_args.append(new_arg)
+        return dummy.clone(), _flat_vmap(
+            func,
+            batch_size,
+            flat_in_dims,
+            flat_all_args,
+            args_spec,
+            out_dims,
+            randomness,
+            **kwargs,
+        )
+
+    from torch._higher_order_ops import scan
+
+    # scan requires a dummy tensor :/
+    _, result = scan(func_to_scan, torch.zeros(0), flat_scanned_args)
+    # We assume out_dims=0, so the result looks like:
+    # (scan_dim, bdim, *).
+    # Flatten the first two dimensions together.
+    #
+    # We can support other out_dims (we'd need to do the
+    # _broadcast_to_and_flatten thing), but the really annoying case is
+    # out_dims=None.
+    return tree_map(lambda x: x.flatten(0, 1), result)
+
+
+def vmap_impl(
+    func, in_dims, out_dims, randomness, chunk_size, chunk_with_scan, *args, **kwargs
+):
     lazy_load_decompositions()
     _check_out_dims_is_int_or_int_pytree(out_dims, func)
     batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
@@ -266,6 +357,19 @@ def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
     )

     if chunk_size is not None:
+        if chunk_with_scan:
+            return _flat_vmap_chunk_with_scan(
+                func,
+                batch_size,
+                flat_in_dims,
+                flat_args,
+                args_spec,
+                chunk_size,
+                randomness,
+                out_dims,
+                **kwargs,
+            )
         chunks_flat_args = _get_chunked_inputs(
             flat_args, flat_in_dims, batch_size, chunk_size
         )
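
To make the reshape/scan/flatten strategy in _flat_vmap_chunk_with_scan concrete, here is a hedged, loop-based reference of the same layout transform (an illustration, not the PR's code; the scan HOP is replaced by a plain Python loop and only a single batched input with in_dim=0 is handled). Note that, as written in the diff, the scan runs chunk_size times and each step vmaps over num_chunks elements:

import torch

def chunk_with_scan_reference(func, x, chunk_size):
    # Split B into (chunk_size, num_chunks, *), "scan" (here: a loop) over
    # dim 0, vmap over the remaining num_chunks batch dim, then flatten the
    # two leading dims back into B.
    batch_size = x.shape[0]
    assert batch_size % chunk_size == 0  # the diff raises NotImplementedError otherwise
    num_chunks = batch_size // chunk_size
    xs = x.reshape(chunk_size, num_chunks, *x.shape[1:])
    results = [torch.vmap(func)(xs[i]) for i in range(chunk_size)]
    # Stacking over the scan dim gives (scan_dim, bdim, *); flatten(0, 1)
    # restores the original batch order because the reshape is row-major.
    return torch.stack(results).flatten(0, 1)

# Sanity check against plain vmap (deterministic func, so results must match)
x = torch.randn(8, 3)
f = lambda t: t * 2
assert torch.allclose(chunk_with_scan_reference(f, x, 2), torch.vmap(f)(x))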