Enhance torch.vmap support from inside torch.compile (#116050)

This work rewrites vmap support in torch.compile by inlining most of
the frames into the existing FX graph. It also enables PyTorch to
support features that were previously missing, such as keyword args.

Fixes: https://github.com/pytorch/pytorch/issues/114306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116050
Approved by: https://github.com/zou3519
Guilherme Leobas
2024-01-19 12:36:51 -03:00
committed by PyTorch MergeBot
parent b2a3d6ba0d
commit 80cf0ce153
18 changed files with 642 additions and 412 deletions
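
For context, a minimal sketch of the kind of call this rewrite targets (illustrative names, not taken from the test suite; `capture_func_transforms` is enabled explicitly, as the tests in this PR do via `config.patch`): `torch.vmap` called inside a compiled function with an unbatched keyword argument forwarded to the mapped function, which previously graph-broke with "NYI - torch.func.vmap: kwargs arguments are currently unsupported." and should now trace as a single graph.

```python
import torch
import torch._dynamo

# Off by default at the time of this PR; the tests enable it per-test.
torch._dynamo.config.capture_func_transforms = True

def fn(x, y):
    # y is forwarded as an (unbatched) keyword argument to the mapped function
    return (x + y).sum(0)

@torch.compile(backend="eager", fullgraph=True)
def wrapper(x, y):
    return torch.vmap(fn)(x, y=y)

x = torch.ones(2, 3)
y = torch.randn(3)
torch.testing.assert_close(wrapper(x, y), torch.vmap(fn)(x, y=y))
```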

View File

@ -34,6 +34,8 @@ per-file-ignores =
torch/utils/cpp_extension.py: B950
torchgen/api/types/__init__.py: F401,F403
torchgen/executorch/api/types/__init__.py: F401,F403
test/dynamo/test_higher_order_ops.py: B950
torch/testing/_internal/dynamo_test_failures.py: B950
optional-ascii-coding = True
exclude =
./.git,

View File

@ -4,6 +4,7 @@ import functools
import pprint
import re
import unittest
import warnings
import functorch.experimental.control_flow as control_flow
@ -24,7 +25,12 @@ from torch._dynamo.testing import (
)
from torch._dynamo.utils import counters, ifdynstaticdefault
from torch._higher_order_ops.wrap import wrap
from torch.testing._internal.common_utils import (
TEST_WITH_TORCHDYNAMO,
xfailIfTorchDynamo,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
@ -2272,13 +2278,13 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3, 3, 3, 3)
fn(x)
gm = backend.graphs[0]
actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sum_2", "add"})
actual_stack = self._get_source_fn_stack(
gm, {"sum_1", "sum_2", "batched_output"}
)
self.assertExpectedInline(
pprint.pformat(actual_stack),
"""\
{'add': ['vmap_impl', 'vmap_impl', 'add'],
'sum_1': ['vmap_impl', 'vmap_impl', 'sum_1'],
'sum_2': ['vmap_impl', 'vmap_impl', 'sum_2']}""",
{'batched_output': ['add'], 'sum_1': ['sum_1'], 'sum_2': ['sum_2']}""",
)
def test_cond_pytree_operands(self):
@ -2370,7 +2376,82 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
torch.compile(fn, backend="eager")(pred, pytree_in)
class HigherOrderOpVmapGuardTests(LoggingTestCase):
@config.patch(capture_func_transforms=True)
@make_logging_test(recompiles=True)
def test_vmap_guard_ok(self, records):
@torch.compile(backend="eager")
def fn(x):
return torch.vmap(lambda x: x.sin())(x)
x = torch.randn(3, 3, 4, 5)
y = fn(x)
# sanity check
self.assertEqual(len(records), 0)
self.assertEqual(x.sin(), y)
# Calling the same function again won't have any effect on guards
z = fn(x)
self.assertEqual(len(records), 0)
self.assertEqual(x.sin(), z)
# calling with a different object will also not affect guards
w = fn(z)
self.assertEqual(len(records), 0)
self.assertEqual(z.sin(), w)
@xfailIfTorchDynamo
@config.patch(capture_func_transforms=True)
@make_logging_test(recompiles=True)
def test_vmap_guard_fail(self, records):
@torch.compile(backend="eager")
def fn(x):
return torch.vmap(lambda x: x.sin())(x)
x = torch.zeros(3, 3, 4, 5)
y = torch.vmap(fn)(x)
self.assertEqual(x.sin(), y)
self.assertEqual(len(records), 0)
# Calling vmap(vmap(fn))(x) should retrigger compilation, as
# _functorch.current_level() is not the same
x = torch.zeros(3, 3, 3, 4, 5)
y = torch.vmap(torch.vmap(fn))(x)
self.assertEqual(x.sin(), y)
self.assertGreater(len(records), 0)
record = self.getRecord(records, "maybe_current_level()")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._C._functorch.maybe_current_level() is None # with vmap_increment_nesting(batch_size, randomness) as vmap_level: # _functorch/vmap.py:399 in _flat_vmap""",
record.getMessage(),
)
class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
def tearDown(self):
# Ensure that in the case of a test failure, the next test won't fail
# because of a previous call to _vmap_increment_nesting that wasn't undone
# i.e. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1
# and the call to increment nesting is not undone
if not TEST_WITH_TORCHDYNAMO:
return
warn = False
while ci := torch._C._functorch.peek_interpreter_stack():
if ci.key() == torch._C._functorch.TransformType.Vmap:
warn = True
torch._C._functorch._vmap_decrement_nesting()
else:
break
if warn:
msg = (
"Interpreter stack is not empty. Test should have called "
"'torch._C._functorch._vmap_decrement_nesting()'"
)
warnings.warn(msg)
def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0):
backend = EagerAndRecordGraphs()
actual = fn(*inputs)
@ -2946,6 +3027,112 @@ class GraphModule(torch.nn.Module):
)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
def test_vmap_get_wrapped(self):
counters.clear()
def g(x):
return x.sin()
@torch.compile(backend="aot_eager", fullgraph=True)
def fn():
return torch.vmap(g)
x = torch.randn(3, 4)
expected = torch.vmap(g)(x)
wrapper = fn()
got = wrapper(x)
self.assertEqual(expected, got)
@config.patch(capture_func_transforms=True)
def test_vmap_with_conditional_graph_break(self):
def g(x):
if len(x.shape) < 2:
torch._dynamo.graph_break()
return x.sin()
else:
return x.cos()
@torch.compile(backend="aot_eager")
def fn(x):
return torch.vmap(g)(x)
counters.clear()
x = torch.randn(2, 3)
expected = x.sin()
got = fn(x)
self.assertEqual(expected, got)
self.assertEqual(len(counters["graph_break"]), 1)
counters.clear()
y = torch.randn(2, 3, 4)
expected = y.cos()
got = fn(y)
self.assertEqual(expected, got)
self.assertEqual(len(counters["graph_break"]), 0)
@config.patch(capture_func_transforms=True)
def test_vmap_with_graph_break(self):
counters.clear()
def g(x):
y = x.cos()
print("hi")
return y.sin()
def fn(x):
return torch.vmap(g)(x)
x = torch.randn(3, 4)
opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
expected = fn(x)
got = opt(x)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(expected, got)
@config.patch(capture_func_transforms=True)
def test_vmap_with_graph_break_2(self):
counters.clear()
def cos(x):
print("cos")
return x.cos()
def sin(x):
print("sin")
return x.sin()
def g(x):
y = cos(x)
return sin(y)
def fn(x):
return torch.vmap(g, randomness="same")(x)
x = torch.randn(3, 4)
opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
expected = fn(x)
got = opt(x)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(expected, got)
def test_vmap_with_graph_break_lambda(self):
counters.clear()
def sin(x):
print("sin")
return x.sin()
def fn(x):
return torch.vmap(lambda x: sin(x))(x)
x = torch.randn(3, 4)
opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
expected = fn(x)
got = opt(x)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(expected, got)
@config.patch(capture_func_transforms=True)
def test_vmap(self):
def fn(x):
@ -2964,22 +3151,24 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
child = L_x_
arg = L_x_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(child); vmap_proxy = child = None
return (call,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select):
sum_1 = select.sum(0)
sum_2 = select.sum(1); select = None
add = sum_1 + sum_2; sum_1 = sum_2 = None
return add
_add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
sum_1 = _add_batch_dim.sum(0)
sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
batched_output = sum_1 + sum_2; sum_1 = sum_2 = None
_remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim,)
""",
)
@ -3003,23 +3192,25 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
child = L_x_
arg = L_x_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(child); vmap_proxy = child = None
return (call,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select):
sum_1 = select.sum(0)
sum_2 = select.sum(1); select = None
add = sum_1 + sum_2; sum_1 = sum_2 = None
add_1 = add + 3; add = None
return add_1
_add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
sum_1 = _add_batch_dim.sum(0)
sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
add = sum_1 + sum_2; sum_1 = sum_2 = None
batched_output = add + 3; add = None
_remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim,)
""",
)
@ -3043,24 +3234,26 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
child = L_x_
arg = L_x_
l_y_ = L_y_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (0, None), 0, 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(child, l_y_); vmap_proxy = child = l_y_ = None
return (call,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select, l_y_):
sum_1 = select.sum(0)
sum_2 = select.sum(1); select = None
add = sum_1 + sum_2; sum_1 = sum_2 = None
add_1 = add + l_y_; add = l_y_ = None
return add_1
_add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
sum_1 = _add_batch_dim.sum(0)
sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
add = sum_1 + sum_2; sum_1 = sum_2 = None
batched_output = add + l_y_; add = l_y_ = None
_remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim,)
""",
)
@ -3085,25 +3278,27 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
child = L_x_
child_1 = L_y_
arg = L_x_
arg_3 = L_y_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
select_1 = child_1.select(1, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (0, 1), 0, 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(child, child_1); vmap_proxy = child = child_1 = None
return (call,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select, select_1):
sum_1 = select.sum(0)
sum_2 = select.sum(1); select = None
add = sum_1 + sum_2; sum_1 = sum_2 = None
add_1 = add + select_1; add = select_1 = None
return add_1
_add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
_add_batch_dim_1 = torch._C._functorch._add_batch_dim(arg_3, 1, 1); arg_3 = None
sum_1 = _add_batch_dim.sum(0)
sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
add = sum_1 + sum_2; sum_1 = sum_2 = None
batched_output = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
_remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim,)
""",
)
@ -3130,25 +3325,27 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
child = L_x_
child_1 = L_y_
arg = L_x_
arg_3 = L_y_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
select_1 = child_1.select(1, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (0, 1), 0, 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(child, child_1); vmap_proxy = child = child_1 = None
return (call,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select, select_1):
sum_1 = select.sum(0)
sum_2 = select.sum(1); select = None
add = sum_1 + sum_2; sum_1 = sum_2 = None
add_1 = add + select_1; add = select_1 = None
return add_1
_add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
_add_batch_dim_1 = torch._C._functorch._add_batch_dim(arg_3, 1, 1); arg_3 = None
sum_1 = _add_batch_dim.sum(0)
sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
add = sum_1 + sum_2; sum_1 = sum_2 = None
batched_output = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
_remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim,)
""",
)
@ -3171,32 +3368,37 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
child = L_x_
child_1 = L_y_
arg = L_x_
arg_3 = L_y_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
_check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
select_1 = child_1.select(0, 0)
vmap_body_1 = self.vmap_body_1
vmap_proxy = torch.func.vmap(vmap_body_1, (0, 0), 0, 'error'); vmap_body_1 = None
call = vmap_proxy.__call__(child, child_1); vmap_proxy = child = child_1 = None
return (call,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select, select_1):
select_2 = select.select(1, 0)
select_3 = select_1.select(1, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (1, 1), 0, 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(select, select_1); vmap_proxy = select = select_1 = None
return call
arg_8 = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
arg_9 = torch._C._functorch._add_batch_dim(arg_3, 0, 1); arg_3 = None
class GraphModule(torch.nn.Module):
def forward(self, select_2, select_3):
add = select_2 + select_3; select_2 = select_3 = None
return add
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions()
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error')
_add_batch_dim_2 = torch._C._functorch._add_batch_dim(arg_8, 1, 2); arg_8 = None
_add_batch_dim_3 = torch._C._functorch._add_batch_dim(arg_9, 1, 2); arg_9 = None
batched_output = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None
batched_output_1 = torch._C._functorch._remove_batch_dim(batched_output, 2, 3, 0); batched_output = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 3, 0); batched_output_1 = None
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim_1,)
""",
)
@ -3220,30 +3422,35 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
child = L_y_
arg = L_y_
l_x_ = L_x_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
_check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
vmap_body_1 = self.vmap_body_1
vmap_proxy = torch.func.vmap(vmap_body_1, (0, None), 0, 'error'); vmap_body_1 = None
call = vmap_proxy.__call__(child, l_x_); vmap_proxy = child = l_x_ = None
return (call,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select, l_x_):
select_1 = select.select(0, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (0, None), 0, 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(select, l_x_); vmap_proxy = select = l_x_ = None
return call
arg_3 = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
class GraphModule(torch.nn.Module):
def forward(self, select_1, l_x_):
mul = l_x_ * select_1; l_x_ = select_1 = None
return mul
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions()
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error')
_add_batch_dim_1 = torch._C._functorch._add_batch_dim(arg_3, 0, 2); arg_3 = None
batched_output = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None
batched_output_1 = torch._C._functorch._remove_batch_dim(batched_output, 2, 3, 0); batched_output = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 5, 0); batched_output_1 = None
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim_1,)
""",
)
@ -3266,23 +3473,24 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
child = L_x_
arg = L_x_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(child); vmap_proxy = child = None
getitem = call[0]
getitem_1 = call[1]; call = None
return (getitem, getitem_1)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select):
sum_1 = select.sum(0)
sum_2 = select.sum(1); select = None
return (sum_1, sum_2)
_add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
batched_output = _add_batch_dim.sum(0)
batched_output_1 = _add_batch_dim.sum(1); _add_batch_dim = None
_remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 2, 0); batched_output = None
_remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 2, 0); batched_output_1 = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim, _remove_batch_dim_1)
""",
)
@ -3305,23 +3513,24 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
child = L_x_
arg = L_x_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (0,), (1, 0), 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(child); vmap_proxy = child = None
getitem = call[0]
getitem_1 = call[1]; call = None
return (getitem, getitem_1)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select):
sum_1 = select.sum(0)
sum_2 = select.sum(1); select = None
return (sum_1, sum_2)
_add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
batched_output = _add_batch_dim.sum(0)
batched_output_1 = _add_batch_dim.sum(1); _add_batch_dim = None
_remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 2, 1); batched_output = None
_remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 2, 0); batched_output_1 = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim, _remove_batch_dim_1)
""",
)
@ -3345,23 +3554,24 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
child = L_x_
arg = L_x_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()
select = child.select(0, 0)
vmap_body_0 = self.vmap_body_0
vmap_proxy = torch.func.vmap(vmap_body_0, (0,), (1, 0), 'error'); vmap_body_0 = None
call = vmap_proxy.__call__(child); vmap_proxy = child = None
getitem = call[0]
getitem_1 = call[1]; call = None
return (getitem, getitem_1)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error')
class GraphModule(torch.nn.Module):
def forward(self, select):
sum_1 = select.sum(0)
sum_2 = select.sum(1); select = None
return (sum_1, sum_2)
_add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
batched_output = _add_batch_dim.sum(0)
batched_output_1 = _add_batch_dim.sum(1); _add_batch_dim = None
_remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 2, 1); batched_output = None
_remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 2, 0); batched_output_1 = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_remove_batch_dim, _remove_batch_dim_1)
""",
)
@ -3376,11 +3586,7 @@ class GraphModule(torch.nn.Module):
actual = fn(x, y)
expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(
dict(counters["graph_break"]),
{"NYI - torch.func.vmap: kwargs arguments are currently unsupported.": 2},
)
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
@ -3399,14 +3605,7 @@ class GraphModule(torch.nn.Module):
actual = fn(x, y)
expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
self.assertEqual(len(counters["graph_break"]), 1)
assert_dict_matches_regex(
self,
dict(counters["graph_break"]),
{
".*torch.vmap with body that accepts non-Tensors as input": 2,
},
)
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
@ -3426,14 +3625,48 @@ class GraphModule(torch.nn.Module):
actual = wrapper_fn(x, y)
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
self.assertEqual(len(counters["graph_break"]), 1)
assert_dict_matches_regex(
self,
dict(counters["graph_break"]),
{
r".*HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)": 2
},
)
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
self.assertEqual(some_list, [1, 1])
@unittest.expectedFailure
@config.patch(capture_func_transforms=True)
def test_vmap_side_effects_append_input(self):
counters.clear()
x = torch.ones(2, 3)
y = torch.randn(2, 3)
some_list = []
def f(x, y):
some_list.append(x)
return x + y
def wrapper_fn(x, y):
return torch.func.vmap(f)(x, y)
actual = wrapper_fn(x, y)
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
def test_vmap_previous_illegal_op_no_graph_break(self):
counters.clear()
# calling .stride() would previously graph break
def bad_fn(x):
y = x.view((4, 3))
y.stride()
return y
def wrapper_fn(x):
return torch.func.vmap(bad_fn)(x)
x = torch.randn(2, 3, 4)
actual = wrapper_fn(x)
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
@ -3461,28 +3694,6 @@ class GraphModule(torch.nn.Module):
)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
def test_vmap_illegal_op_graph_break(self):
counters.clear()
def bad_fn(x):
x.stride()
return x
def wrapper_fn(x):
return torch.func.vmap(bad_fn)(x)
x = torch.randn(3, 3, 3)
actual = wrapper_fn(x)
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
self.assertEqual(len(counters["graph_break"]), 1)
assert_dict_matches_regex(
self,
dict(counters["graph_break"]),
{".*Illegal getattr invocation stride in strict mode": 2},
)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
def test_vmap_multiple_invocation_in_dims(self):
counters.clear()
@ -3498,7 +3709,7 @@ class GraphModule(torch.nn.Module):
actual = opt(x, 0), opt(x, 1), opt(x, 2)
self.assertEqual(expected, actual)
self.assertEqual(cnt.frame_count, 3)
self.assertEqual(cnt.op_count, 9)
self.assertEqual(cnt.op_count, 33)
@config.patch(capture_func_transforms=True)
def test_vmap_multiple_invocation_out_dims(self):
@ -3515,7 +3726,7 @@ class GraphModule(torch.nn.Module):
actual = opt(x, 0), opt(x, 1), opt(x, 2)
self.assertEqual(expected, actual)
self.assertEqual(cnt.frame_count, 3)
self.assertEqual(cnt.op_count, 9)
self.assertEqual(cnt.op_count, 30)
@config.patch(capture_func_transforms=True)
def test_vmap_new_tensor_in_body(self):
@ -3528,7 +3739,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(
3,
)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=True)
opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
expected = wrapper_fn(x)
actual = opt(x)
self.assertEqual(expected, actual)
@ -3542,7 +3753,7 @@ class GraphModule(torch.nn.Module):
return torch.func.vmap(fn)(x)
x = torch.randn(3)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=True)
opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
expected = wrapper_fn(x)
actual = opt(x)
self.assertEqual(expected, actual)
@ -3553,7 +3764,7 @@ class GraphModule(torch.nn.Module):
return torch.func.vmap(lambda t: torch.add(t, 0.5))(x)
x = torch.randn(3)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=True)
opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
expected = wrapper_fn(x)
actual = opt(x)
self.assertEqual(expected, actual)

View File

@ -13,6 +13,7 @@ def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
def maybe_get_bdim(tensor: Tensor) -> int: ...
def maybe_get_level(tensor: Tensor) -> int: ...
def maybe_current_level() -> Optional[int]: ...
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...

View File

@ -360,6 +360,15 @@ class GuardBuilder(GuardBuilderBase):
self._produce_guard_code(guard, [code], provided_guarded_object=self.get(base))
def FUNCTORCH_CURRENT_LEVEL_MATCH(self, guard: Guard):
# Invalidate the graph if a call to vmap has been made prior to this
# This is super conservative as the interpreter stack may not contain
# vmap
code = [
"torch._C._functorch.maybe_current_level() is None",
]
self._produce_guard_code(guard, code)
def EQUALS_MATCH(self, guard: Guard):
ref = self.arg_ref(guard)
val = self.get(guard.name)
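
As a rough illustration of what this guard observes (a sketch using the `maybe_current_level` binding added elsewhere in this PR; the helper `f` and the shapes are illustrative): outside any functorch transform the interpreter stack is empty and the condition holds, while under `torch.vmap` the call returns the active level, so the guard fails and Dynamo recompiles for that nesting depth.

```python
import torch

# No functorch transform is active here, so the guard condition holds.
assert torch._C._functorch.maybe_current_level() is None

seen_levels = []

def f(x):
    # Under torch.vmap a Vmap interpreter sits on the dynamic layer stack,
    # so the same query returns its level (typically 1 for the outermost
    # vmap) and the guard above would fail.
    seen_levels.append(torch._C._functorch.maybe_current_level())
    return x.sin()

torch.vmap(f)(torch.randn(3, 4))
print(seen_levels)  # e.g. [1]
```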

View File

@ -124,7 +124,6 @@ BUILTIN_SKIPLIST = (
# third party libraries skiplist is defined by str, because users may not use these libraries.
# we should use lazy import & skip in the future.
THIRDPARTY_SKIPLIST = (
"functorch",
"fx2trt_oss",
"networkx",
"numpy",
@ -201,6 +200,7 @@ MOD_INLINELIST = {
"torch._dynamo._trace_wrapped_higher_order_op",
"torch._dynamo.comptime",
"torch._dynamo.polyfill",
"torch._functorch.vmap",
"torch._inductor.test_operators",
"torch.amp.autocast_mode",
"torch.ao.nn",

View File

@ -2057,6 +2057,8 @@ class InstructionTranslator(InstructionTranslatorBase):
speculation_log=speculation_log,
)
self._throw_if_in_vmap()
# as soon as we create the tracing context we should keep it active, so any calls
# into dynamo apis can rely on finding it
with tracing(self.output.tracing_context), self.set_current_tx():
@ -2093,6 +2095,22 @@ class InstructionTranslator(InstructionTranslatorBase):
if name in f_locals:
self._freevars_ids[name] = id(f_locals[name])
def _throw_if_in_vmap(self):
# Fallback to eager in case of a graph break inside vmap
eager = torch._dynamo.lookup_backend("eager")
compiler_fn = inspect.getattr_static(
self.output.compiler_fn, "compiler_fn", self.output.compiler_fn
)
ci = torch._C._functorch.peek_interpreter_stack()
if (
ci is not None
and ci.key() == torch._C._functorch.TransformType.Vmap
and compiler_fn is not eager
):
# if it reaches here, it means Dynamo failed to inline vmap
msg = "torch.vmap(fn) requires the function to be inlined by dynamo"
unimplemented(msg)
def get_example_value(self, source: Source):
if isinstance(source, LocalSource):
return self.f_locals[source.local_name]

View File

@ -20,9 +20,11 @@ import torch
from .utils import hashable, is_function, NP_SUPPORTED_MODULES
from .variables import (
FunctorchVmapHigherOrderVariable,
SkipFilesVariable,
TorchCtxManagerClassVariable,
TorchInGraphFunctionVariable,
UserFunctionVariable,
)
from .variables.base import VariableTracker
@ -132,6 +134,31 @@ manual_torch_name_rule_map = {
"torch.resize_as_": SkipFilesVariable,
"torch.resize_as_sparse_": SkipFilesVariable,
"torch.get_default_device": TorchInGraphFunctionVariable,
# functorch
"torch._functorch.vmap._check_int_or_none": UserFunctionVariable,
"torch._functorch.vmap._check_out_dims_is_int_or_int_pytree": UserFunctionVariable,
"torch._functorch.vmap._check_randomness_arg": UserFunctionVariable,
"torch._functorch.vmap._chunked_vmap": UserFunctionVariable,
"torch._functorch.vmap._concat_chunked_outputs": UserFunctionVariable,
"torch._functorch.vmap._create_batched_inputs": UserFunctionVariable,
"torch._functorch.vmap._flat_vmap": UserFunctionVariable,
"torch._functorch.vmap._flatten_chunks_output": UserFunctionVariable,
"torch._functorch.vmap._get_chunked_inputs": UserFunctionVariable,
"torch._functorch.vmap._get_name": UserFunctionVariable,
"torch._functorch.vmap._maybe_remove_batch_dim": UserFunctionVariable,
"torch._functorch.vmap._num_outputs": UserFunctionVariable,
"torch._functorch.vmap._process_batched_inputs": UserFunctionVariable,
"torch._functorch.vmap._unwrap_batched": UserFunctionVariable,
"torch._functorch.vmap._validate_and_get_batch_size": UserFunctionVariable,
"torch._functorch.vmap.doesnt_support_saved_tensors_hooks": UserFunctionVariable,
"torch._functorch.vmap.get_chunk_sizes": UserFunctionVariable,
# lazy_load_decompositions uses a lock that is not supported yet in dynamo
# "torch._functorch.vmap.lazy_load_decompositions": UserFunctionVariable,
"torch._functorch.vmap.restore_vmap": UserFunctionVariable,
"torch._functorch.apis.vmap": UserFunctionVariable,
"torch._functorch.vmap.unwrap_batched": UserFunctionVariable,
"torch._functorch.vmap.vmap_impl": FunctorchVmapHigherOrderVariable,
"torch._functorch.vmap.wrap_batched": UserFunctionVariable,
}
@ -140,6 +167,7 @@ torch_ctx_manager_classes = {
k: TorchCtxManagerClassVariable
for k in [
"torch._C.DisableTorchFunctionSubclass",
"torch._functorch.vmap.vmap_increment_nesting",
"torch.amp.autocast_mode.autocast",
"torch.autograd.grad_mode.enable_grad",
"torch.autograd.grad_mode.inference_mode",
@ -396,6 +424,13 @@ torch_c_binding_in_graph_functions = {
"torch._C._fft.fft_rfft2",
"torch._C._fft.fft_rfftfreq",
"torch._C._fft.fft_rfftn",
"torch._C._functorch._add_batch_dim",
"torch._C._functorch._remove_batch_dim",
"torch._C._functorch._vmap_incr_nest",
"torch._C._functorch._vmap_decr_nest",
"torch._C._functorch._vmap_increment_nesting",
"torch._C._functorch._vmap_decrement_nesting",
"torch._C._functorch.is_batchedtensor",
"torch._C._free_And_Remove_DeleterFn",
"torch._C._freeze_module",
"torch._C._from_dlpack",
@ -2135,29 +2170,7 @@ torch_non_c_binding_in_graph_functions = {
"torch._functorch.utils.enable_single_level_autograd_function",
"torch._functorch.utils.exposed_in",
"torch._functorch.utils.unwrap_dead_wrappers",
"torch._functorch.vmap._as_tuple",
"torch._functorch.vmap._check_int_or_none",
"torch._functorch.vmap._check_out_dims_is_int_or_int_pytree",
"torch._functorch.vmap._check_randomness_arg",
"torch._functorch.vmap._chunked_vmap",
"torch._functorch.vmap._concat_chunked_outputs",
"torch._functorch.vmap._create_batched_inputs",
"torch._functorch.vmap._flat_vmap",
"torch._functorch.vmap._flatten_chunks_output",
"torch._functorch.vmap._get_chunked_inputs",
"torch._functorch.vmap._get_name",
"torch._functorch.vmap._maybe_remove_batch_dim",
"torch._functorch.vmap._num_outputs",
"torch._functorch.vmap._process_batched_inputs",
"torch._functorch.vmap._unwrap_batched",
"torch._functorch.vmap._validate_and_get_batch_size",
"torch._functorch.vmap.doesnt_support_saved_tensors_hooks",
"torch._functorch.vmap.get_chunk_sizes",
"torch._functorch.vmap.lazy_load_decompositions",
"torch._functorch.vmap.restore_vmap",
"torch._functorch.vmap.unwrap_batched",
"torch._functorch.vmap.vmap_impl",
"torch._functorch.vmap.wrap_batched",
"torch._guards.compile_context",
"torch._guards.detect_fake_mode",
"torch._guards.tracing",

View File

@ -9,6 +9,7 @@ from .ctx_manager import (
InferenceModeVariable,
StreamContextVariable,
StreamVariable,
VmapIncrementNestingCtxManagerVariable,
WithExitFunctionVariable,
)
from .dicts import (
@ -23,7 +24,10 @@ from .functions import (
UserFunctionVariable,
UserMethodVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .higher_order_ops import (
FunctorchVmapHigherOrderVariable,
TorchHigherOrderOperatorVariable,
)
from .iter import (
CountIteratorVariable,
CycleIteratorVariable,

View File

@ -1507,6 +1507,9 @@ def wrap_fx_proxy_cls(
torch._utils._element_size,
torch.seed,
operator.mod,
torch._C._functorch._vmap_increment_nesting,
torch._C._functorch._vmap_decrement_nesting,
torch._functorch.vmap._validate_and_get_batch_size,
# some mac builds are missing torch.distributed.get_rank()
getattr(torch.distributed, "get_rank", _missing),
getattr(torch.distributed, "get_world_size", _missing),

View File

@ -106,6 +106,7 @@ class BuiltinVariable(VariableTracker):
chr,
divmod,
float,
getattr,
int,
len,
max,

View File

@ -148,6 +148,48 @@ class GenericContextWrappingVariable(ContextWrappingVariable):
return x
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
"""represents torch VMap increment/decrement nesting"""
# A guard is needed as the vmap level is baked into the torch FX graph
# generated. This is fine if vmap is only called from within the function
# being compiled. But the FX graph may be invalid in the case of a vmap
# call from eager that calls the compiled function, as the vmap levels
# may be different.
_guards_singleton = Guard(
GlobalStateSource(), GuardBuilder.FUNCTORCH_CURRENT_LEVEL_MATCH
)
@staticmethod
def create(tx, target_values, **kwargs):
var = VmapIncrementNestingCtxManagerVariable(
target_values=target_values,
initial_values=None,
**kwargs,
)
return var
def enter(self, tx):
install_guard(self._guards_singleton)
batch_size, randomness = self.target_values
vmap_level = torch._C._functorch._vmap_increment_nesting(batch_size, randomness)
self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
self.state.proxy = tx.output.create_node(
"call_function",
torch._C._functorch._vmap_increment_nesting,
(batch_size, randomness),
{},
)
return variables.ConstantVariable.create(vmap_level)
def exit(self, tx, *args):
self.state.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
)
return variables.ConstantVariable.create(None)
class GradModeVariable(ContextWrappingVariable):
"""represents torch.{no_grad,enable_grad,set_grad_mode}()"""

View File

@ -105,6 +105,13 @@ class BaseUserFunctionVariable(VariableTracker):
class UserFunctionVariable(BaseUserFunctionVariable):
"""Some unsupported user-defined global function"""
@classmethod
def create_with_source(cls, value, source):
return cls(
value,
source=source,
)
def __init__(self, fn, is_constant=False, **kwargs):
super().__init__(**kwargs)
if getattr(fn, "_dynamo_marked_constant", False):
@ -254,6 +261,10 @@ class UserFunctionVariable(BaseUserFunctionVariable):
def export_freevars(self, parent, child):
pass
def call_hasattr(self, tx, name: str) -> VariableTracker:
result = hasattr(self.fn, name)
return variables.ConstantVariable.create(result)
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":

View File

@ -1,6 +1,5 @@
import contextlib
import functools
import itertools
import logging
import types
@ -10,7 +9,6 @@ import torch._C
import torch.fx
import torch.nn
import torch.onnx.operators
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import deepcopy_to_fake_tensor, get_fake_value, get_real_value
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.builtin import BuiltinVariable
@ -504,8 +502,6 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
return OutDtypeHigherOrderVariable(value, source, **kwargs)
elif value is torch._functorch.eager_transforms.grad_impl:
return FunctorchGradHigherOrderVariable(value, source, **kwargs)
elif value is torch._functorch.vmap.vmap_impl:
return FunctorchVmapHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "wrap":
return WrapHigherOrderVariable(value, source, **kwargs)
elif value.__name__ in (
@ -1031,165 +1027,17 @@ class FunctorchGradHigherOrderVariable(TorchHigherOrderOperatorVariable):
return TupleVariable([TupleVariable(items), aux])
class FunctorchVmapHigherOrderVariable(TorchHigherOrderOperatorVariable):
class FunctorchVmapHigherOrderVariable(UserFunctionVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import ConstantVariable, TensorVariable
from .builder import wrap_fx_proxy
if not torch._dynamo.config.capture_func_transforms:
unimplemented(
"torch.func.vmap capture is disabled, "
"it can be turned on by setting "
"`torch._dynamo.config.capture_func_transforms=True`"
)
# unpack args
fn = args[0]
in_dims = args[1]
out_dims = args[2]
randomness = args[3]
chunk_size = args[4]
batch_input_args = args[5:]
if not isinstance(in_dims, (ConstantVariable, TupleVariable)):
unimplemented("torch.func.vmap: in_dims is not an int or tuple variable.")
if not isinstance(out_dims, (ConstantVariable, TupleVariable)):
unimplemented("torch.func.vmap: out_dims is not an int or tuple variable.")
if len(kwargs) > 0:
unimplemented(
"NYI - torch.func.vmap: kwargs arguments are currently unsupported."
)
if chunk_size.value is not None:
unimplemented(
"NYI - torch.func.vmap is not implemented when chunk_size is passed"
)
# Trace into tree_flatten with the list of batch_input_args.
flat_args, arg_spec = _make_inlined(tx, pytree.tree_flatten)(
ListVariable(batch_input_args)
).unpack_var_sequence(tx)
# Transform in_dims into a list if it's not an integer literal.
in_dims_v = (
in_dims
if isinstance(in_dims.as_python_constant(), int)
else BuiltinVariable(list).call_function(tx, [in_dims], {})
)
# Trace into _broadcast_to_and_flatten with the transformed in_dims.
broadcasted_in_dims = _make_inlined(tx, pytree._broadcast_to_and_flatten)(
in_dims_v, arg_spec
)
# We want to pass unbatched input to speculate subgraph.
# So we loop through the inputs and select only one sample
# from the batch.
unbatched_input_args = []
for arg, in_dim in zip(
flat_args.unpack_var_sequence(tx),
broadcasted_in_dims.unpack_var_sequence(tx),
):
if in_dim is not None:
assert isinstance(arg, TensorVariable)
unbatched_arg = arg.call_method(
tx, "select", [in_dim, ConstantVariable.create(0)], {}
)
unbatched_input_args.append(unbatched_arg)
else:
unbatched_input_args.append(arg)
# Ban ops like `stride`, `storage_offset` in the traced functions.
# NOTE: We are conservatively banning more ops (vmap should be able
# to handle a few of them).
with tx.strict_translation_mode():
# trace through the function with unbatched inputs.
_, body_graph, body_lifted_freevars = speculate_subgraph(
tx,
fn,
# Returns a ListVariable, since that's where we started flattening.
# However, we really want to pass the inner Python list as argument.
_make_inlined(tx, pytree.tree_unflatten)(
ListVariable(unbatched_input_args), arg_spec
).unpack_var_sequence(tx),
{},
"torch.vmap",
source_target=self.value,
set_subgraph_inputs="manual",
)
body_name = add_subgraph(
tx,
self.source,
"vmap_body",
torch.fx.GraphModule(tx.output.nn_modules, body_graph),
)
body_node = make_attr(tx, body_name)
# body_lifted_variable should not be treated as batched.
# So here we update `in_dims` to reflect that.
# NOTE: updated_in_dims is flat list, it is ok for now
# as speculate_subgraph does not supports functions with non-Tensor args.
# (so we graph-break above)
updated_in_dims = TupleVariable(
broadcasted_in_dims.unpack_var_sequence(tx)
+ [
ConstantVariable.create(None),
]
* len(body_lifted_freevars)
)
vmap_proxy_args = (
body_node,
*(arg.as_proxy() for arg in (updated_in_dims, out_dims, randomness)),
)
# vmap_proxy corresponds to `vmap_proxy = vmap(fn, *vmap_args, **vmap_kwargs)`
vmap_proxy = tx.output.create_proxy(
"call_function",
torch.func.vmap,
args=tuple(vmap_proxy_args),
kwargs={},
name="vmap_proxy",
)
proxy_batched_fn_args = tuple(
arg.as_proxy() for arg in batch_input_args
) + tuple(body_lifted_freevars)
# We compute the example_value by actually calling
# `vmap` with FakeTensors.
fake_batched_fn_args = itertools.chain(
(get_fake_value(arg.as_proxy().node, tx) for arg in batch_input_args),
(get_fake_value(arg.node, tx) for arg in body_lifted_freevars),
)
actual_in_dims = tuple(
pytree.tree_map(lambda x: x.value, updated_in_dims.items)
)
# NOTE: `body_graph` might have operators which
# will create new tensors. So it is required
# that we run `vmap` under FakeMode.
with tx.fake_mode, enable_python_dispatcher():
example_value = torch._functorch.vmap.vmap_impl(
torch.fx.GraphModule(tx.output.nn_modules, body_graph),
actual_in_dims,
out_dims.as_python_constant(),
randomness.value,
chunk_size.value,
*fake_batched_fn_args,
)
# proxy corresponds to `call = vmap_proxy(*batched_fn_args, **batched_fn_kwargs)`
proxy = vmap_proxy(*proxy_batched_fn_args)
return wrap_fx_proxy(
tx=tx,
proxy=proxy,
example_value=example_value,
)
return super().call_function(tx, args, kwargs)
class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):

View File

@ -27,6 +27,7 @@ from ..guards import GuardBuilder
from ..utils import (
check_constant_args,
check_unspec_python_args,
guard_if_dyn,
has_torch_function,
product,
proxy_args_kwargs,
@ -140,7 +141,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import GradModeVariable, InferenceModeVariable, StreamVariable
from . import (
GradModeVariable,
InferenceModeVariable,
StreamVariable,
VmapIncrementNestingCtxManagerVariable,
)
if self.value is torch.no_grad:
if len(args) == 1 and isinstance(
@ -193,6 +199,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
elif self.value is torch._C.DisableTorchFunctionSubclass:
assert not (args or kwargs)
return TorchFunctionDisableVariable.create(tx)
elif self.value is torch._functorch.vmap.vmap_increment_nesting:
assert len(args) == 2
return VmapIncrementNestingCtxManagerVariable.create(
tx,
[guard_if_dyn(x) for x in args],
)
class TorchInGraphFunctionVariable(BaseTorchVariable):
@ -238,10 +250,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
):
tx.mark_inconsistent_side_effects()
return ConstantVariable.create(tracing_state_functions[self.value])
elif self.value in (
torch._functorch.vmap.vmap_impl,
torch._functorch.eager_transforms.grad_impl,
):
elif self.value in (torch._functorch.eager_transforms.grad_impl,):
return TorchHigherOrderOperatorVariable.make(
self.value,
source=self.source,

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import torch
import contextlib
import functools
import threading
from torch import Tensor
@ -274,17 +275,10 @@ def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
args_spec, out_dims, randomness, **kwargs)
from torch._dynamo import disable
# remove @disable once #114306 is fixed
@disable
def wrapper():
# If chunk_size is not specified.
return _flat_vmap(
func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
)
return wrapper()
# If chunk_size is not specified.
return _flat_vmap(
func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
)
def get_chunk_sizes(total_elems, chunk_size):
n_chunks = n_chunks = total_elems // chunk_size
@ -390,15 +384,22 @@ def _check_randomness_arg(randomness):
raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")
@contextlib.contextmanager
def vmap_increment_nesting(batch_size, randomness):
try:
vmap_level = _vmap_increment_nesting(batch_size, randomness)
yield vmap_level
finally:
_vmap_decrement_nesting()
@doesnt_support_saved_tensors_hooks
def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs):
vmap_level = _vmap_increment_nesting(batch_size, randomness)
try:
with vmap_increment_nesting(batch_size, randomness) as vmap_level:
batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
batched_outputs = func(*batched_inputs, **kwargs)
return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
finally:
_vmap_decrement_nesting()
# `restore_vmap` is a private helper function. It is vmap but has the following
@ -424,13 +425,10 @@ def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, r
@doesnt_support_saved_tensors_hooks
def restore_vmap(func, in_dims, batch_size, randomness):
def inner(*args, **kwargs):
vmap_level = _vmap_increment_nesting(batch_size, randomness)
try:
with vmap_increment_nesting(batch_size, randomness) as vmap_level:
batched_inputs = wrap_batched(args, in_dims, vmap_level)
batched_outputs = func(*batched_inputs, **kwargs)
return unwrap_batched(batched_outputs, vmap_level)
finally:
_vmap_decrement_nesting()
return inner
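
For reference, a small sketch of how the refactored context manager composes with the batching primitives that now appear in the inlined FX graphs above (assumes a build containing this change; the input tensor and batch size are illustrative).

```python
import torch
from torch._functorch.vmap import vmap_increment_nesting

x = torch.randn(4, 3)
# Mirrors the structure of _flat_vmap: enter a vmap level, wrap the input as
# a BatchedTensor, run the body, then unwrap the batched output.
with vmap_increment_nesting(4, "error") as vmap_level:
    batched = torch._C._functorch._add_batch_dim(x, 0, vmap_level)
    batched_output = batched.sin()
    out = torch._C._functorch._remove_batch_dim(batched_output, vmap_level, 4, 0)

torch.testing.assert_close(out, x.sin())
```

Expressing the nesting bookkeeping as a context manager is what allows Dynamo to model it through the new VmapIncrementNestingCtxManagerVariable instead of graph-breaking on the previous try/finally.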

View File

@ -370,6 +370,15 @@ static int64_t currentLevel() {
return current_level;
}
static c10::optional<int64_t> maybe_current_level() {
auto maybe_layer = maybeCurrentDynamicLayer();
if (maybe_layer.has_value()) {
int current_level = maybe_layer->layerId();
return current_level;
}
return nullopt;
}
static void tls_set_vmap_excluded(bool excluded) {
c10::impl::tls_set_dispatch_key_excluded(
c10::DispatchKey::FuncTorchBatched, excluded);
@ -474,6 +483,7 @@ void initFuncTorchBindings(PyObject* module) {
m.def("get_unwrapped", &get_unwrapped);
m.def("maybe_get_level", &maybe_get_level);
m.def("maybe_get_bdim", &maybe_get_bdim);
m.def("maybe_current_level", &maybe_current_level);
m.def("current_level", &currentLevel);
m.def("tls_set_vmap_excluded", &tls_set_vmap_excluded);
m.def("_set_dynamic_layer_keys_included", &_set_dynamic_layer_keys_included);

View File

@ -2335,6 +2335,50 @@ dynamo_expected_failures = {
"TestControlFlowTraced.test_map_functionalized", # functorch/test_control_flow
"TestControlFlowTraced.test_nested_map_cond_symbolic", # functorch/test_control_flow
"TestControlFlowTraced.test_nested_map_cond_real", # functorch/test_control_flow
"TestJacCPU.test_against_reference_correctness_different_devices_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_against_reference_default_arg_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_against_reference_multi_input_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_against_reference_multi_input_multi_output_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_against_reference_simple_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_against_reference_unrelated_outputs_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_against_reference_zero_dim_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_argnums_defaults_to_zero_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_aux_pytree_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_chunk_jacrev_composition__preallocate_and_copy_False_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_chunk_jacrev_composition__preallocate_and_copy_True_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_dimensionality_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_empty_output_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_hessian_simple_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_inplace_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_jac_with_non_tensor_args_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_multiple_inputs_outputs_pytree_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_multiple_inputs_pytree_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_multiple_outputs_multiple_argnums_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_multiple_outputs_single_argnums_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_outputs_can_any_pytree_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_unrelated_input_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestJacCPU.test_unrelated_output_jacfwd_cpu", # functorch/test_eager_transforms.py
"TestVmapJvpInplaceViewCPU.test_all_dual_base_inplace_cpu", # functorch/test_eager_transforms.py
"TestVmapJvpInplaceViewCPU.test_all_dual_base_view_inplace_cpu", # functorch/test_eager_transforms.py
"TestVmapJvpInplaceViewCPU.test_all_dual_no_view_cpu", # functorch/test_eager_transforms.py
"TestVmapJvpInplaceViewCPU.test_right_dual_base_prop_cpu", # functorch/test_eager_transforms.py
"TestVmapJvpInplaceViewCPU.test_right_dual_view_prop_cpu", # functorch/test_eager_transforms.py
"TestHessianCPU.test_hessian_vectorize_correctness_multi_input_cpu", # functorch/test_eager_transforms.py
"TestHessianCPU.test_hessian_vectorize_correctness_simple_cpu", # functorch/test_eager_transforms.py
"TestHessianCPU.test_hessian_vectorize_correctness_unrelated_outputs_cpu", # functorch/test_eager_transforms.py
"TestHessianCPU.test_jacfwd_different_levels_cpu", # functorch/test_eager_transforms.py
"TestExamplesCorrectnessCPU.test_ensemble_regression_mechanism_functional_call_cpu", # functorch/test_eager_transforms.py
"TestExamplesCorrectnessCPU.test_ensemble_regression_mechanism_make_functional_cpu", # functorch/test_eager_transforms.py
"TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_AlphaDropout_mechanism_functional_call_cpu", # functorch/test_eager_transforms.py
"TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_AlphaDropout_mechanism_make_functional_cpu", # functorch/test_eager_transforms.py
"TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_Dropout_mechanism_functional_call_cpu", # functorch/test_eager_transforms.py
"TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_Dropout_mechanism_make_functional_cpu", # functorch/test_eager_transforms.py
"TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_FeatureAlphaDropout_mechanism_functional_call_cpu", # functorch/test_eager_transforms.py
"TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_FeatureAlphaDropout_mechanism_make_functional_cpu", # functorch/test_eager_transforms.py
"TestHigherOrderOperatorInteractionCPU.test_vmap_grad_sum_cpu", # functorch/test_eager_transforms.py
"TestFunctionalizeCPU.test_multioutput_view_cpu", # functorch/test_eager_transforms.py
"TestFunctionalizeCPU.test_simple_view_cpu", # functorch/test_eager_transforms.py
"TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu", # functorch/test_eager_transforms.py
"TestMetaKernel.test_addmm_invalid_dtype", # lazy/test_meta_kernel
"TestVerifyCorrectness.test_incorrect_verify_true", # dynamo/test_verify_correctness
"TestVerifyCorrectness.test_torchscript", # dynamo/test_verify_correctness
@ -2579,7 +2623,11 @@ dynamo_expected_failures = {
"FuncTorchHigherOrderOpTests.test_vmap_free_const", # dynamo/test_higher_order_ops
"FuncTorchHigherOrderOpTests.test_vmap_multiple_invocation_in_dims", # dynamo/test_higher_order_ops
"FuncTorchHigherOrderOpTests.test_grad", # dynamo/test_higher_order_ops
"FuncTorchHigherOrderOpTests.test_vmap_illegal_op_graph_break", # dynamo/test_higher_order_ops
"FuncTorchHigherOrderOpTests.test_vmap_with_conditional_graph_break", # dynamo/test_higher_order_ops
"FuncTorchHigherOrderOpTests.test_vmap_with_graph_break", # dynamo/test_higher_order_ops
"FuncTorchHigherOrderOpTests.test_vmap_with_graph_break_2", # dynamo/test_higher_order_ops
"FuncTorchHigherOrderOpTests.test_vmap_with_graph_break_lambda", # dynamo/test_higher_order_ops
"FuncTorchHigherOrderOpTests.test_vmap_previous_illegal_op_no_graph_break", # dynamo/test_higher_order_ops
"HigherOrderOpTests.test_cond_pytree_operands", # dynamo/test_higher_order_ops
"HigherOrderOpTests.test_cond_branches_no_arguments_no_closure", # dynamo/test_higher_order_ops
"FuncTorchHigherOrderOpTests.test_vmap_side_effects", # dynamo/test_higher_order_ops
@ -2775,6 +2823,8 @@ dynamo_expected_failures = {
"TestVmapOperatorsLegacy.test_unbind", # test_legacy_vmap
"TestVmapAPILegacy.test_non_default_in_dims_out_dims", # test_legacy_vmap
"TestVmapOperatorsLegacy.test_T_numpy", # test_legacy_vmap
"TestNamedTensor.test_expand", # test_namedtensor
"TestNamedTensor.test_masked_fill", # test_namedtensor
"TestNamedTensor.test_addmv", # test_namedtensor
"TestNamedTensor.test_cummax_cummin", # test_namedtensor
"TestNamedTensor.test_no_jit_script_support", # test_namedtensor

View File

@ -928,7 +928,7 @@ def tree_map_(
"""
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
tuple(map(func, *flat_args)) # consume and exhaust the iterable
return tree