Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Enhance torch.vmap support from inside torch.compile (#116050)

This work rewrites vmap support in torch.compile by inlining most of the frames into the existing FX graph. It also enables PyTorch to support features that were previously missing, such as keyword args.
Fixes: https://github.com/pytorch/pytorch/issues/114306
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116050
Approved by: https://github.com/zou3519
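As a rough illustration of the user-visible effect (this snippet is not part of the commit; the backend choice, shapes, and the capture_func_transforms flag are assumptions that may vary by PyTorch version), a torch.vmap call that passes keyword arguments to the mapped function can now be captured by torch.compile instead of graph-breaking:

    import torch
    import torch._dynamo.config as dynamo_config

    # The tests in this commit enable capture via @config.patch(capture_func_transforms=True).
    dynamo_config.capture_func_transforms = True

    def scaled_sum(x, *, coef):
        # `coef` is a keyword argument; vmap passes kwargs through unbatched.
        return coef * x.sum(0)

    @torch.compile(backend="aot_eager", fullgraph=True)
    def fn(x):
        return torch.vmap(scaled_sum)(x, coef=2.0)

    out = fn(torch.randn(4, 3))  # vmap over dim 0 -> output shape (4,)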
Committed by: PyTorch MergeBot
Parent: b2a3d6ba0d
Commit: 80cf0ce153

.flake8 (2)
@@ -34,6 +34,8 @@ per-file-ignores =
     torch/utils/cpp_extension.py: B950
     torchgen/api/types/__init__.py: F401,F403
     torchgen/executorch/api/types/__init__.py: F401,F403
+    test/dynamo/test_higher_order_ops.py: B950
+    torch/testing/_internal/dynamo_test_failures.py: B950
 optional-ascii-coding = True
 exclude =
     ./.git,
@@ -4,6 +4,7 @@ import functools
 import pprint
 import re
 import unittest
+import warnings

 import functorch.experimental.control_flow as control_flow

@@ -24,7 +25,12 @@ from torch._dynamo.testing import (
 )
 from torch._dynamo.utils import counters, ifdynstaticdefault
 from torch._higher_order_ops.wrap import wrap
+from torch.testing._internal.common_utils import (
+    TEST_WITH_TORCHDYNAMO,
+    xfailIfTorchDynamo,
+)
 from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


 requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
@@ -2272,13 +2278,13 @@ class GraphModule(torch.nn.Module):
         x = torch.randn(3, 3, 3, 3)
         fn(x)
         gm = backend.graphs[0]
-        actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sum_2", "add"})
+        actual_stack = self._get_source_fn_stack(
+            gm, {"sum_1", "sum_2", "batched_output"}
+        )
         self.assertExpectedInline(
             pprint.pformat(actual_stack),
             """\
-{'add': ['vmap_impl', 'vmap_impl', 'add'],
- 'sum_1': ['vmap_impl', 'vmap_impl', 'sum_1'],
- 'sum_2': ['vmap_impl', 'vmap_impl', 'sum_2']}""",
+{'batched_output': ['add'], 'sum_1': ['sum_1'], 'sum_2': ['sum_2']}""",
         )

     def test_cond_pytree_operands(self):
@@ -2370,7 +2376,82 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
         torch.compile(fn, backend="eager")(pred, pytree_in)


+class HigherOrderOpVmapGuardTests(LoggingTestCase):
+    @config.patch(capture_func_transforms=True)
+    @make_logging_test(recompiles=True)
+    def test_vmap_guard_ok(self, records):
+        @torch.compile(backend="eager")
+        def fn(x):
+            return torch.vmap(lambda x: x.sin())(x)
+
+        x = torch.randn(3, 3, 4, 5)
+        y = fn(x)
+        # sanity check
+        self.assertEqual(len(records), 0)
+        self.assertEqual(x.sin(), y)
+
+        # Calling the same function again won't have any effect on guards
+        z = fn(x)
+        self.assertEqual(len(records), 0)
+        self.assertEqual(x.sin(), z)
+
+        # calling with a different object will also not affect guards
+        w = fn(z)
+        self.assertEqual(len(records), 0)
+        self.assertEqual(z.sin(), w)
+
+    @xfailIfTorchDynamo
+    @config.patch(capture_func_transforms=True)
+    @make_logging_test(recompiles=True)
+    def test_vmap_guard_fail(self, records):
+        @torch.compile(backend="eager")
+        def fn(x):
+            return torch.vmap(lambda x: x.sin())(x)
+
+        x = torch.zeros(3, 3, 4, 5)
+        y = torch.vmap(fn)(x)
+        self.assertEqual(x.sin(), y)
+        self.assertEqual(len(records), 0)
+
+        # call vmap(vmap(fn))(x) should retrigger compilation as
+        # _functorch.current_level() is not the same
+        x = torch.zeros(3, 3, 3, 4, 5)
+        y = torch.vmap(torch.vmap(fn))(x)
+        self.assertEqual(x.sin(), y)
+        self.assertGreater(len(records), 0)
+        record = self.getRecord(records, "maybe_current_level()")
+        self.assertIn(
+            """\
+triggered by the following guard failure(s):
+- torch._C._functorch.maybe_current_level() is None # with vmap_increment_nesting(batch_size, randomness) as vmap_level: # _functorch/vmap.py:399 in _flat_vmap""",
+            record.getMessage(),
+        )
+
+
 class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
+    def tearDown(self):
+        # Ensure that in the case of a test failure, the next test won't fail
+        # because of a previous call to _vmap_increment_nesting that wasn't undone
+        # i.e. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1
+        # and the call to increment nesting is not undone
+        if not TEST_WITH_TORCHDYNAMO:
+            return
+
+        warn = False
+        while ci := torch._C._functorch.peek_interpreter_stack():
+            if ci.key() == torch._C._functorch.TransformType.Vmap:
+                warn = True
+                torch._C._functorch._vmap_decrement_nesting()
+            else:
+                break
+
+        if warn:
+            msg = (
+                "Interpreter stack is not empty. Test should have called "
+                "'torch._C._functorch._vmap_decrement_nesting()'"
+            )
+            warnings.warn(msg)
+
     def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0):
         backend = EagerAndRecordGraphs()
         actual = fn(*inputs)
@@ -2946,6 +3027,112 @@ class GraphModule(torch.nn.Module):
         )
         self.assertEqual(actual, expected)

+    @config.patch(capture_func_transforms=True)
+    def test_vmap_get_wrapped(self):
+        counters.clear()
+
+        def g(x):
+            return x.sin()
+
+        @torch.compile(backend="aot_eager", fullgraph=True)
+        def fn():
+            return torch.vmap(g)
+
+        x = torch.randn(3, 4)
+        expected = torch.vmap(g)(x)
+        wrapper = fn()
+        got = wrapper(x)
+        self.assertEqual(expected, got)
+
+    @config.patch(capture_func_transforms=True)
+    def test_vmap_with_conditional_graph_break(self):
+        def g(x):
+            if len(x.shape) < 2:
+                torch._dynamo.graph_break()
+                return x.sin()
+            else:
+                return x.cos()
+
+        @torch.compile(backend="aot_eager")
+        def fn(x):
+            return torch.vmap(g)(x)
+
+        counters.clear()
+        x = torch.randn(2, 3)
+        expected = x.sin()
+        got = fn(x)
+        self.assertEqual(expected, got)
+        self.assertEqual(len(counters["graph_break"]), 1)
+
+        counters.clear()
+        y = torch.randn(2, 3, 4)
+        expected = y.cos()
+        got = fn(y)
+        self.assertEqual(expected, got)
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+    @config.patch(capture_func_transforms=True)
+    def test_vmap_with_graph_break(self):
+        counters.clear()
+
+        def g(x):
+            y = x.cos()
+            print("hi")
+            return y.sin()
+
+        def fn(x):
+            return torch.vmap(g)(x)
+
+        x = torch.randn(3, 4)
+        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
+        expected = fn(x)
+        got = opt(x)
+        self.assertEqual(len(counters["graph_break"]), 1)
+        self.assertEqual(expected, got)
+
+    @config.patch(capture_func_transforms=True)
+    def test_vmap_with_graph_break_2(self):
+        counters.clear()
+
+        def cos(x):
+            print("cos")
+            return x.cos()
+
+        def sin(x):
+            print("sin")
+            return x.sin()
+
+        def g(x):
+            y = cos(x)
+            return sin(y)
+
+        def fn(x):
+            return torch.vmap(g, randomness="same")(x)
+
+        x = torch.randn(3, 4)
+        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
+        expected = fn(x)
+        got = opt(x)
+        self.assertEqual(len(counters["graph_break"]), 1)
+        self.assertEqual(expected, got)
+
+    def test_vmap_with_graph_break_lambda(self):
+        counters.clear()
+
+        def sin(x):
+            print("sin")
+            return x.sin()
+
+        def fn(x):
+            return torch.vmap(lambda x: sin(x))(x)
+
+        x = torch.randn(3, 4)
+        opt = torch.compile(fn, backend="aot_eager", fullgraph=False)
+        expected = fn(x)
+        got = opt(x)
+        self.assertEqual(len(counters["graph_break"]), 1)
+        self.assertEqual(expected, got)
+
     @config.patch(capture_func_transforms=True)
     def test_vmap(self):
         def fn(x):
@@ -2964,22 +3151,24 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor):
-        child = L_x_
+        arg = L_x_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        vmap_body_0 = self.vmap_body_0
-        vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error'); vmap_body_0 = None
-        call = vmap_proxy.__call__(child); vmap_proxy = child = None
-        return (call,)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select):
-            sum_1 = select.sum(0)
-            sum_2 = select.sum(1); select = None
-            add = sum_1 + sum_2; sum_1 = sum_2 = None
-            return add
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
+
+        _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+
+        sum_1 = _add_batch_dim.sum(0)
+        sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
+        batched_output = sum_1 + sum_2; sum_1 = sum_2 = None
+
+        _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim,)
 """,
         )

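For readers unfamiliar with these private bindings, the following standalone sketch (not part of the diff; it assumes a fresh process where the new nesting level is 1, which is why the captured graph hard-codes level 1) reproduces in eager mode the node sequence Dynamo now emits for torch.vmap: increment the vmap nesting, add a batch dimension, run the body, remove the batch dimension, and decrement the nesting.

    import torch
    from torch._C import _functorch

    x = torch.randn(3, 4)                                   # batch_size = 3 over dim 0
    level = _functorch._vmap_increment_nesting(3, "error")  # randomness='error', as in the graph
    try:
        bx = _functorch._add_batch_dim(x, 0, level)         # body sees a (4,)-shaped view
        batched_output = bx.sum(0)                          # the vmapped body
        out = _functorch._remove_batch_dim(batched_output, level, 3, 0)
    finally:
        _functorch._vmap_decrement_nesting()

    assert torch.allclose(out, torch.vmap(lambda t: t.sum(0))(x))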
@@ -3003,23 +3192,25 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor):
-        child = L_x_
+        arg = L_x_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        vmap_body_0 = self.vmap_body_0
-        vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error'); vmap_body_0 = None
-        call = vmap_proxy.__call__(child); vmap_proxy = child = None
-        return (call,)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select):
-            sum_1 = select.sum(0)
-            sum_2 = select.sum(1); select = None
-            add = sum_1 + sum_2; sum_1 = sum_2 = None
-            add_1 = add + 3; add = None
-            return add_1
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
+
+        _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+
+        sum_1 = _add_batch_dim.sum(0)
+        sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
+        add = sum_1 + sum_2; sum_1 = sum_2 = None
+        batched_output = add + 3; add = None
+
+        _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim,)
 """,
         )

@@ -3043,24 +3234,26 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
-        child = L_x_
+        arg = L_x_
         l_y_ = L_y_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        vmap_body_0 = self.vmap_body_0
-        vmap_proxy = torch.func.vmap(vmap_body_0, (0, None), 0, 'error'); vmap_body_0 = None
-        call = vmap_proxy.__call__(child, l_y_); vmap_proxy = child = l_y_ = None
-        return (call,)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select, l_y_):
-            sum_1 = select.sum(0)
-            sum_2 = select.sum(1); select = None
-            add = sum_1 + sum_2; sum_1 = sum_2 = None
-            add_1 = add + l_y_; add = l_y_ = None
-            return add_1
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
+
+        _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+
+        sum_1 = _add_batch_dim.sum(0)
+        sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
+        add = sum_1 + sum_2; sum_1 = sum_2 = None
+        batched_output = add + l_y_; add = l_y_ = None
+
+        _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim,)
 """,
         )

@@ -3085,25 +3278,27 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
-        child = L_x_
-        child_1 = L_y_
+        arg = L_x_
+        arg_3 = L_y_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        select_1 = child_1.select(1, 0)
-        vmap_body_0 = self.vmap_body_0
-        vmap_proxy = torch.func.vmap(vmap_body_0, (0, 1), 0, 'error'); vmap_body_0 = None
-        call = vmap_proxy.__call__(child, child_1); vmap_proxy = child = child_1 = None
-        return (call,)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select, select_1):
-            sum_1 = select.sum(0)
-            sum_2 = select.sum(1); select = None
-            add = sum_1 + sum_2; sum_1 = sum_2 = None
-            add_1 = add + select_1; add = select_1 = None
-            return add_1
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
+
+        _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+        _add_batch_dim_1 = torch._C._functorch._add_batch_dim(arg_3, 1, 1); arg_3 = None
+
+        sum_1 = _add_batch_dim.sum(0)
+        sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
+        add = sum_1 + sum_2; sum_1 = sum_2 = None
+        batched_output = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
+
+        _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim,)
 """,
         )

@@ -3130,25 +3325,27 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
-        child = L_x_
-        child_1 = L_y_
+        arg = L_x_
+        arg_3 = L_y_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        select_1 = child_1.select(1, 0)
-        vmap_body_0 = self.vmap_body_0
-        vmap_proxy = torch.func.vmap(vmap_body_0, (0, 1), 0, 'error'); vmap_body_0 = None
-        call = vmap_proxy.__call__(child, child_1); vmap_proxy = child = child_1 = None
-        return (call,)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select, select_1):
-            sum_1 = select.sum(0)
-            sum_2 = select.sum(1); select = None
-            add = sum_1 + sum_2; sum_1 = sum_2 = None
-            add_1 = add + select_1; add = select_1 = None
-            return add_1
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
+
+        _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+        _add_batch_dim_1 = torch._C._functorch._add_batch_dim(arg_3, 1, 1); arg_3 = None
+
+        sum_1 = _add_batch_dim.sum(0)
+        sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
+        add = sum_1 + sum_2; sum_1 = sum_2 = None
+        batched_output = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
+
+        _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 3, 0); batched_output = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim,)
 """,
         )

@@ -3171,32 +3368,37 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
-        child = L_x_
-        child_1 = L_y_
+        arg = L_x_
+        arg_3 = L_y_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
-        _check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        select_1 = child_1.select(0, 0)
-        vmap_body_1 = self.vmap_body_1
-        vmap_proxy = torch.func.vmap(vmap_body_1, (0, 0), 0, 'error'); vmap_body_1 = None
-        call = vmap_proxy.__call__(child, child_1); vmap_proxy = child = child_1 = None
-        return (call,)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select, select_1):
-            select_2 = select.select(1, 0)
-            select_3 = select_1.select(1, 0)
-            vmap_body_0 = self.vmap_body_0
-            vmap_proxy = torch.func.vmap(vmap_body_0, (1, 1), 0, 'error'); vmap_body_0 = None
-            call = vmap_proxy.__call__(select, select_1); vmap_proxy = select = select_1 = None
-            return call
-
-        class GraphModule(torch.nn.Module):
-            def forward(self, select_2, select_3):
-                add = select_2 + select_3; select_2 = select_3 = None
-                return add
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error')
+
+        arg_8 = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+        arg_9 = torch._C._functorch._add_batch_dim(arg_3, 0, 1); arg_3 = None
+
+        lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions()
+
+        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error')
+
+        _add_batch_dim_2 = torch._C._functorch._add_batch_dim(arg_8, 1, 2); arg_8 = None
+        _add_batch_dim_3 = torch._C._functorch._add_batch_dim(arg_9, 1, 2); arg_9 = None
+
+        batched_output = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None
+
+        batched_output_1 = torch._C._functorch._remove_batch_dim(batched_output, 2, 3, 0); batched_output = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+
+        _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 3, 0); batched_output_1 = None
+
+        _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim_1,)
 """,
         )

@@ -3220,30 +3422,35 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
-        child = L_y_
+        arg = L_y_
         l_x_ = L_x_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
-        _check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        vmap_body_1 = self.vmap_body_1
-        vmap_proxy = torch.func.vmap(vmap_body_1, (0, None), 0, 'error'); vmap_body_1 = None
-        call = vmap_proxy.__call__(child, l_x_); vmap_proxy = child = l_x_ = None
-        return (call,)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select, l_x_):
-            select_1 = select.select(0, 0)
-            vmap_body_0 = self.vmap_body_0
-            vmap_proxy = torch.func.vmap(vmap_body_0, (0, None), 0, 'error'); vmap_body_0 = None
-            call = vmap_proxy.__call__(select, l_x_); vmap_proxy = select = l_x_ = None
-            return call
-
-        class GraphModule(torch.nn.Module):
-            def forward(self, select_1, l_x_):
-                mul = l_x_ * select_1; l_x_ = select_1 = None
-                return mul
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error')
+
+        arg_3 = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+
+        lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions()
+
+        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error')
+
+        _add_batch_dim_1 = torch._C._functorch._add_batch_dim(arg_3, 0, 2); arg_3 = None
+
+        batched_output = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None
+
+        batched_output_1 = torch._C._functorch._remove_batch_dim(batched_output, 2, 3, 0); batched_output = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+
+        _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 5, 0); batched_output_1 = None
+
+        _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim_1,)
 """,
         )

@@ -3266,23 +3473,24 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor):
-        child = L_x_
+        arg = L_x_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        vmap_body_0 = self.vmap_body_0
-        vmap_proxy = torch.func.vmap(vmap_body_0, (0,), 0, 'error'); vmap_body_0 = None
-        call = vmap_proxy.__call__(child); vmap_proxy = child = None
-        getitem = call[0]
-        getitem_1 = call[1]; call = None
-        return (getitem, getitem_1)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select):
-            sum_1 = select.sum(0)
-            sum_2 = select.sum(1); select = None
-            return (sum_1, sum_2)
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error')
+
+        _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+
+        batched_output = _add_batch_dim.sum(0)
+        batched_output_1 = _add_batch_dim.sum(1); _add_batch_dim = None
+
+        _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 2, 0); batched_output = None
+        _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 2, 0); batched_output_1 = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim, _remove_batch_dim_1)
 """,
         )

@@ -3305,23 +3513,24 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor):
-        child = L_x_
+        arg = L_x_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        vmap_body_0 = self.vmap_body_0
-        vmap_proxy = torch.func.vmap(vmap_body_0, (0,), (1, 0), 'error'); vmap_body_0 = None
-        call = vmap_proxy.__call__(child); vmap_proxy = child = None
-        getitem = call[0]
-        getitem_1 = call[1]; call = None
-        return (getitem, getitem_1)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select):
-            sum_1 = select.sum(0)
-            sum_2 = select.sum(1); select = None
-            return (sum_1, sum_2)
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error')
+
+        _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+
+        batched_output = _add_batch_dim.sum(0)
+        batched_output_1 = _add_batch_dim.sum(1); _add_batch_dim = None
+
+        _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 2, 1); batched_output = None
+        _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 2, 0); batched_output_1 = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim, _remove_batch_dim_1)
 """,
         )

@@ -3345,23 +3554,24 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor):
-        child = L_x_
+        arg = L_x_

-        _check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions()

-        select = child.select(0, 0)
-        vmap_body_0 = self.vmap_body_0
-        vmap_proxy = torch.func.vmap(vmap_body_0, (0,), (1, 0), 'error'); vmap_body_0 = None
-        call = vmap_proxy.__call__(child); vmap_proxy = child = None
-        getitem = call[0]
-        getitem_1 = call[1]; call = None
-        return (getitem, getitem_1)
-
-    class GraphModule(torch.nn.Module):
-        def forward(self, select):
-            sum_1 = select.sum(0)
-            sum_2 = select.sum(1); select = None
-            return (sum_1, sum_2)
+        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error')
+
+        _add_batch_dim = torch._C._functorch._add_batch_dim(arg, 0, 1); arg = None
+
+        batched_output = _add_batch_dim.sum(0)
+        batched_output_1 = _add_batch_dim.sum(1); _add_batch_dim = None
+
+        _remove_batch_dim = torch._C._functorch._remove_batch_dim(batched_output, 1, 2, 1); batched_output = None
+        _remove_batch_dim_1 = torch._C._functorch._remove_batch_dim(batched_output_1, 1, 2, 0); batched_output_1 = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting()
+        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
+        return (_remove_batch_dim, _remove_batch_dim_1)
 """,
         )

@@ -3376,11 +3586,7 @@ class GraphModule(torch.nn.Module):

         actual = fn(x, y)
         expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
-        self.assertEqual(len(counters["graph_break"]), 1)
-        self.assertEqual(
-            dict(counters["graph_break"]),
-            {"NYI - torch.func.vmap: kwargs arguments are currently unsupported.": 2},
-        )
+        self.assertEqual(len(counters["graph_break"]), 0)
         self.assertEqual(actual, expected)

     @config.patch(capture_func_transforms=True)

@@ -3399,14 +3605,7 @@ class GraphModule(torch.nn.Module):

         actual = fn(x, y)
         expected = torch.compile(fn, backend="aot_eager", fullgraph=False)(x, y)
-        self.assertEqual(len(counters["graph_break"]), 1)
-        assert_dict_matches_regex(
-            self,
-            dict(counters["graph_break"]),
-            {
-                ".*torch.vmap with body that accepts non-Tensors as input": 2,
-            },
-        )
+        self.assertEqual(len(counters["graph_break"]), 0)
         self.assertEqual(actual, expected)

     @config.patch(capture_func_transforms=True)
@@ -3426,14 +3625,48 @@ class GraphModule(torch.nn.Module):

         actual = wrapper_fn(x, y)
         expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
-        self.assertEqual(len(counters["graph_break"]), 1)
-        assert_dict_matches_regex(
-            self,
-            dict(counters["graph_break"]),
-            {
-                r".*HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)": 2
-            },
-        )
+        self.assertEqual(len(counters["graph_break"]), 0)
+        self.assertEqual(actual, expected)
+        self.assertEqual(some_list, [1, 1])
+
+    @unittest.expectedFailure
+    @config.patch(capture_func_transforms=True)
+    def test_vmap_side_effects_append_input(self):
+        counters.clear()
+        x = torch.ones(2, 3)
+        y = torch.randn(2, 3)
+
+        some_list = []
+
+        def f(x, y):
+            some_list.append(x)
+            return x + y
+
+        def wrapper_fn(x, y):
+            return torch.func.vmap(f)(x, y)
+
+        actual = wrapper_fn(x, y)
+        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
+        self.assertEqual(len(counters["graph_break"]), 0)
+        self.assertEqual(actual, expected)
+
+    @config.patch(capture_func_transforms=True)
+    def test_vmap_previous_illegal_op_no_graph_break(self):
+        counters.clear()
+
+        # calling .stride() would previously graph break
+        def bad_fn(x):
+            y = x.view((4, 3))
+            y.stride()
+            return y
+
+        def wrapper_fn(x):
+            return torch.func.vmap(bad_fn)(x)
+
+        x = torch.randn(2, 3, 4)
+        actual = wrapper_fn(x)
+        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
+        self.assertEqual(len(counters["graph_break"]), 0)
         self.assertEqual(actual, expected)

     @config.patch(capture_func_transforms=True)
@@ -3461,28 +3694,6 @@ class GraphModule(torch.nn.Module):
         )
         self.assertEqual(actual, expected)

-    @config.patch(capture_func_transforms=True)
-    def test_vmap_illegal_op_graph_break(self):
-        counters.clear()
-
-        def bad_fn(x):
-            x.stride()
-            return x
-
-        def wrapper_fn(x):
-            return torch.func.vmap(bad_fn)(x)
-
-        x = torch.randn(3, 3, 3)
-        actual = wrapper_fn(x)
-        expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
-        self.assertEqual(len(counters["graph_break"]), 1)
-        assert_dict_matches_regex(
-            self,
-            dict(counters["graph_break"]),
-            {".*Illegal getattr invocation stride in strict mode": 2},
-        )
-        self.assertEqual(actual, expected)
-
     @config.patch(capture_func_transforms=True)
     def test_vmap_multiple_invocation_in_dims(self):
         counters.clear()
@@ -3498,7 +3709,7 @@ class GraphModule(torch.nn.Module):
         actual = opt(x, 0), opt(x, 1), opt(x, 2)
         self.assertEqual(expected, actual)
         self.assertEqual(cnt.frame_count, 3)
-        self.assertEqual(cnt.op_count, 9)
+        self.assertEqual(cnt.op_count, 33)

     @config.patch(capture_func_transforms=True)
     def test_vmap_multiple_invocation_out_dims(self):

@@ -3515,7 +3726,7 @@ class GraphModule(torch.nn.Module):
         actual = opt(x, 0), opt(x, 1), opt(x, 2)
         self.assertEqual(expected, actual)
         self.assertEqual(cnt.frame_count, 3)
-        self.assertEqual(cnt.op_count, 9)
+        self.assertEqual(cnt.op_count, 30)

     @config.patch(capture_func_transforms=True)
     def test_vmap_new_tensor_in_body(self):

@@ -3528,7 +3739,7 @@ class GraphModule(torch.nn.Module):
         x = torch.randn(
             3,
         )
-        opt = torch.compile(wrapper_fn, backend="eager", fullgraph=True)
+        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
         expected = wrapper_fn(x)
         actual = opt(x)
         self.assertEqual(expected, actual)

@@ -3542,7 +3753,7 @@ class GraphModule(torch.nn.Module):
             return torch.func.vmap(fn)(x)

         x = torch.randn(3)
-        opt = torch.compile(wrapper_fn, backend="eager", fullgraph=True)
+        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
         expected = wrapper_fn(x)
         actual = opt(x)
         self.assertEqual(expected, actual)

@@ -3553,7 +3764,7 @@ class GraphModule(torch.nn.Module):
             return torch.func.vmap(lambda t: torch.add(t, 0.5))(x)

         x = torch.randn(3)
-        opt = torch.compile(wrapper_fn, backend="eager", fullgraph=True)
+        opt = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=True)
         expected = wrapper_fn(x)
         actual = opt(x)
         self.assertEqual(expected, actual)
@@ -13,6 +13,7 @@ def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
 def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
 def maybe_get_bdim(tensor: Tensor) -> int: ...
 def maybe_get_level(tensor: Tensor) -> int: ...
+def maybe_current_level() -> Optional[int]: ...
 def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
 def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
 def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...

@@ -360,6 +360,15 @@ class GuardBuilder(GuardBuilderBase):

         self._produce_guard_code(guard, [code], provided_guarded_object=self.get(base))

+    def FUNCTORCH_CURRENT_LEVEL_MATCH(self, guard: Guard):
+        # Invalidate the graph if a call to vmap has been made prior to this
+        # This is super conservative as the interpreter stack may not contain
+        # vmap
+        code = [
+            "torch._C._functorch.maybe_current_level() is None",
+        ]
+        self._produce_guard_code(guard, code)
+
     def EQUALS_MATCH(self, guard: Guard):
         ref = self.arg_ref(guard)
         val = self.get(guard.name)
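A small sketch of what the new guard checks (not part of the diff; it only relies on `maybe_current_level()` returning None when no functorch transform is active, which is the behaviour the guard depends on):

    import torch
    from torch._C import _functorch

    # Outside of any functorch transform, no interpreter is active, so the guard
    # expression "torch._C._functorch.maybe_current_level() is None" holds.
    assert _functorch.maybe_current_level() is None

    def f(x):
        # Inside torch.vmap an interpreter is on the stack, so the same expression
        # is False there; a frame compiled under this guard must then recompile
        # (this is what test_vmap_guard_fail above exercises).
        assert _functorch.maybe_current_level() is not None
        return x.sin()

    torch.vmap(f)(torch.randn(3))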
@@ -124,7 +124,6 @@ BUILTIN_SKIPLIST = (
 # third party libraries skiplist is defined by str, because users may not use these libraries.
 # we should use lazy import & skip in the future.
 THIRDPARTY_SKIPLIST = (
-    "functorch",
     "fx2trt_oss",
     "networkx",
     "numpy",

@@ -201,6 +200,7 @@ MOD_INLINELIST = {
     "torch._dynamo._trace_wrapped_higher_order_op",
     "torch._dynamo.comptime",
     "torch._dynamo.polyfill",
+    "torch._functorch.vmap",
     "torch._inductor.test_operators",
     "torch.amp.autocast_mode",
     "torch.ao.nn",
@@ -2057,6 +2057,8 @@ class InstructionTranslator(InstructionTranslatorBase):
             speculation_log=speculation_log,
         )

+        self._throw_if_in_vmap()
+
         # as soon as we create the tracing context we should keep it active, so any calls
         # into dynamo apis can rely on finding it
         with tracing(self.output.tracing_context), self.set_current_tx():

@@ -2093,6 +2095,22 @@ class InstructionTranslator(InstructionTranslatorBase):
             if name in f_locals:
                 self._freevars_ids[name] = id(f_locals[name])

+    def _throw_if_in_vmap(self):
+        # Fallback to eager in case of a graph break inside vmap
+        eager = torch._dynamo.lookup_backend("eager")
+        compiler_fn = inspect.getattr_static(
+            self.output.compiler_fn, "compiler_fn", self.output.compiler_fn
+        )
+        ci = torch._C._functorch.peek_interpreter_stack()
+        if (
+            ci is not None
+            and ci.key() == torch._C._functorch.TransformType.Vmap
+            and compiler_fn is not eager
+        ):
+            # if it reaches here, it means Dynamo failed to inline vmap
+            msg = "torch.vmap(fn) requires the function to be inlined by dynamo"
+            unimplemented(msg)
+
     def get_example_value(self, source: Source):
         if isinstance(source, LocalSource):
             return self.f_locals[source.local_name]
@@ -20,9 +20,11 @@ import torch
 from .utils import hashable, is_function, NP_SUPPORTED_MODULES

 from .variables import (
+    FunctorchVmapHigherOrderVariable,
     SkipFilesVariable,
     TorchCtxManagerClassVariable,
     TorchInGraphFunctionVariable,
+    UserFunctionVariable,
 )

 from .variables.base import VariableTracker

@@ -132,6 +134,31 @@ manual_torch_name_rule_map = {
     "torch.resize_as_": SkipFilesVariable,
     "torch.resize_as_sparse_": SkipFilesVariable,
     "torch.get_default_device": TorchInGraphFunctionVariable,
+    # functorch
+    "torch._functorch.vmap._check_int_or_none": UserFunctionVariable,
+    "torch._functorch.vmap._check_out_dims_is_int_or_int_pytree": UserFunctionVariable,
+    "torch._functorch.vmap._check_randomness_arg": UserFunctionVariable,
+    "torch._functorch.vmap._chunked_vmap": UserFunctionVariable,
+    "torch._functorch.vmap._concat_chunked_outputs": UserFunctionVariable,
+    "torch._functorch.vmap._create_batched_inputs": UserFunctionVariable,
+    "torch._functorch.vmap._flat_vmap": UserFunctionVariable,
+    "torch._functorch.vmap._flatten_chunks_output": UserFunctionVariable,
+    "torch._functorch.vmap._get_chunked_inputs": UserFunctionVariable,
+    "torch._functorch.vmap._get_name": UserFunctionVariable,
+    "torch._functorch.vmap._maybe_remove_batch_dim": UserFunctionVariable,
+    "torch._functorch.vmap._num_outputs": UserFunctionVariable,
+    "torch._functorch.vmap._process_batched_inputs": UserFunctionVariable,
+    "torch._functorch.vmap._unwrap_batched": UserFunctionVariable,
+    "torch._functorch.vmap._validate_and_get_batch_size": UserFunctionVariable,
+    "torch._functorch.vmap.doesnt_support_saved_tensors_hooks": UserFunctionVariable,
+    "torch._functorch.vmap.get_chunk_sizes": UserFunctionVariable,
+    # lazy_load_decompositions uses a lock that is not supported yet in dynamo
+    # "torch._functorch.vmap.lazy_load_decompositions": UserFunctionVariable,
+    "torch._functorch.vmap.restore_vmap": UserFunctionVariable,
+    "torch._functorch.apis.vmap": UserFunctionVariable,
+    "torch._functorch.vmap.unwrap_batched": UserFunctionVariable,
+    "torch._functorch.vmap.vmap_impl": FunctorchVmapHigherOrderVariable,
+    "torch._functorch.vmap.wrap_batched": UserFunctionVariable,
 }


@@ -140,6 +167,7 @@ torch_ctx_manager_classes = {
     k: TorchCtxManagerClassVariable
     for k in [
         "torch._C.DisableTorchFunctionSubclass",
+        "torch._functorch.vmap.vmap_increment_nesting",
         "torch.amp.autocast_mode.autocast",
         "torch.autograd.grad_mode.enable_grad",
         "torch.autograd.grad_mode.inference_mode",

@@ -396,6 +424,13 @@ torch_c_binding_in_graph_functions = {
     "torch._C._fft.fft_rfft2",
     "torch._C._fft.fft_rfftfreq",
     "torch._C._fft.fft_rfftn",
+    "torch._C._functorch._add_batch_dim",
+    "torch._C._functorch._remove_batch_dim",
+    "torch._C._functorch._vmap_incr_nest",
+    "torch._C._functorch._vmap_decr_nest",
+    "torch._C._functorch._vmap_increment_nesting",
+    "torch._C._functorch._vmap_decrement_nesting",
+    "torch._C._functorch.is_batchedtensor",
     "torch._C._free_And_Remove_DeleterFn",
     "torch._C._freeze_module",
     "torch._C._from_dlpack",

@@ -2135,29 +2170,7 @@ torch_non_c_binding_in_graph_functions = {
     "torch._functorch.utils.enable_single_level_autograd_function",
     "torch._functorch.utils.exposed_in",
     "torch._functorch.utils.unwrap_dead_wrappers",
-    "torch._functorch.vmap._as_tuple",
-    "torch._functorch.vmap._check_int_or_none",
-    "torch._functorch.vmap._check_out_dims_is_int_or_int_pytree",
-    "torch._functorch.vmap._check_randomness_arg",
-    "torch._functorch.vmap._chunked_vmap",
-    "torch._functorch.vmap._concat_chunked_outputs",
-    "torch._functorch.vmap._create_batched_inputs",
-    "torch._functorch.vmap._flat_vmap",
-    "torch._functorch.vmap._flatten_chunks_output",
-    "torch._functorch.vmap._get_chunked_inputs",
-    "torch._functorch.vmap._get_name",
-    "torch._functorch.vmap._maybe_remove_batch_dim",
-    "torch._functorch.vmap._num_outputs",
-    "torch._functorch.vmap._process_batched_inputs",
-    "torch._functorch.vmap._unwrap_batched",
-    "torch._functorch.vmap._validate_and_get_batch_size",
-    "torch._functorch.vmap.doesnt_support_saved_tensors_hooks",
-    "torch._functorch.vmap.get_chunk_sizes",
     "torch._functorch.vmap.lazy_load_decompositions",
-    "torch._functorch.vmap.restore_vmap",
-    "torch._functorch.vmap.unwrap_batched",
-    "torch._functorch.vmap.vmap_impl",
-    "torch._functorch.vmap.wrap_batched",
     "torch._guards.compile_context",
     "torch._guards.detect_fake_mode",
     "torch._guards.tracing",
@@ -9,6 +9,7 @@ from .ctx_manager import (
     InferenceModeVariable,
     StreamContextVariable,
     StreamVariable,
+    VmapIncrementNestingCtxManagerVariable,
     WithExitFunctionVariable,
 )
 from .dicts import (

@@ -23,7 +24,10 @@ from .functions import (
     UserFunctionVariable,
     UserMethodVariable,
 )
-from .higher_order_ops import TorchHigherOrderOperatorVariable
+from .higher_order_ops import (
+    FunctorchVmapHigherOrderVariable,
+    TorchHigherOrderOperatorVariable,
+)
 from .iter import (
     CountIteratorVariable,
     CycleIteratorVariable,
@@ -1507,6 +1507,9 @@ def wrap_fx_proxy_cls(
             torch._utils._element_size,
             torch.seed,
             operator.mod,
+            torch._C._functorch._vmap_increment_nesting,
+            torch._C._functorch._vmap_decrement_nesting,
+            torch._functorch.vmap._validate_and_get_batch_size,
             # some mac builds are missing torch.distributed.get_rank()
             getattr(torch.distributed, "get_rank", _missing),
             getattr(torch.distributed, "get_world_size", _missing),

@@ -106,6 +106,7 @@ class BuiltinVariable(VariableTracker):
             chr,
             divmod,
             float,
+            getattr,
             int,
             len,
             max,
@@ -148,6 +148,48 @@ class GenericContextWrappingVariable(ContextWrappingVariable):
         return x


+class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
+    """represents torch VMap increment/decrement nesting"""
+
+    # A guard is needed as the vmap level is baked into the torch FX graph
+    # generated. This is fine if vmap is only called from within the function
+    # being compiled. But the FX graph may be invalid in the case of a vmap
+    # call from eager that calls the compiled function, as the vmap levels
+    # may be different.
+    _guards_singleton = Guard(
+        GlobalStateSource(), GuardBuilder.FUNCTORCH_CURRENT_LEVEL_MATCH
+    )
+
+    @staticmethod
+    def create(tx, target_values, **kwargs):
+        var = VmapIncrementNestingCtxManagerVariable(
+            target_values=target_values,
+            initial_values=None,
+            **kwargs,
+        )
+        return var
+
+    def enter(self, tx):
+        install_guard(self._guards_singleton)
+        batch_size, randomness = self.target_values
+        vmap_level = torch._C._functorch._vmap_increment_nesting(batch_size, randomness)
+        self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
+        self.state.proxy = tx.output.create_node(
+            "call_function",
+            torch._C._functorch._vmap_increment_nesting,
+            (batch_size, randomness),
+            {},
+        )
+        return variables.ConstantVariable.create(vmap_level)
+
+    def exit(self, tx, *args):
+        self.state.cleanup()
+        tx.output.create_node(
+            "call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
+        )
+        return variables.ConstantVariable.create(None)
+
+
 class GradModeVariable(ContextWrappingVariable):
     """represents torch.{no_grad,enable_grad,set_grad_mode}()"""

@@ -105,6 +105,13 @@ class BaseUserFunctionVariable(VariableTracker):
 class UserFunctionVariable(BaseUserFunctionVariable):
     """Some unsupported user-defined global function"""

+    @classmethod
+    def create_with_source(cls, value, source):
+        return cls(
+            value,
+            source=source,
+        )
+
     def __init__(self, fn, is_constant=False, **kwargs):
         super().__init__(**kwargs)
         if getattr(fn, "_dynamo_marked_constant", False):
@@ -254,6 +261,10 @@ class UserFunctionVariable(BaseUserFunctionVariable):
     def export_freevars(self, parent, child):
         pass

+    def call_hasattr(self, tx, name: str) -> VariableTracker:
+        result = hasattr(self.fn, name)
+        return variables.ConstantVariable.create(result)
+
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
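The new call_hasattr answers hasattr() queries on user-defined functions at trace time. A small hedged sketch of the pattern this enables; the helper and attribute names are purely illustrative.

import torch

def helper(x):
    return x + 1

@torch.compile(backend="eager")
def fn(x):
    # with call_hasattr, this check on a user function is resolved to a
    # constant (False here) while tracing
    if hasattr(helper, "_inline_hint"):
        return helper(x) * 2
    return helper(x)

print(fn(torch.randn(3)))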
@@ -1,6 +1,5 @@
 import contextlib
 import functools
-import itertools
 import logging
 import types

@@ -10,7 +9,6 @@ import torch._C
 import torch.fx
 import torch.nn
 import torch.onnx.operators
-from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import deepcopy_to_fake_tensor, get_fake_value, get_real_value
 from torch._dynamo.variables.base import VariableTracker
 from torch._dynamo.variables.builtin import BuiltinVariable
@@ -504,8 +502,6 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
             return OutDtypeHigherOrderVariable(value, source, **kwargs)
         elif value is torch._functorch.eager_transforms.grad_impl:
             return FunctorchGradHigherOrderVariable(value, source, **kwargs)
-        elif value is torch._functorch.vmap.vmap_impl:
-            return FunctorchVmapHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "wrap":
             return WrapHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ in (
@@ -1031,165 +1027,17 @@ class FunctorchGradHigherOrderVariable(TorchHigherOrderOperatorVariable):
         return TupleVariable([TupleVariable(items), aux])


-class FunctorchVmapHigherOrderVariable(TorchHigherOrderOperatorVariable):
+class FunctorchVmapHigherOrderVariable(UserFunctionVariable):
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
-        from . import ConstantVariable, TensorVariable
-        from .builder import wrap_fx_proxy
-
         if not torch._dynamo.config.capture_func_transforms:
             unimplemented(
                 "torch.func.vmap capture is disabled, "
                 "it can be turned on by setting "
                 "`torch._dynamo.config.capture_func_transforms=True`"
             )
-        # unpack args
-        fn = args[0]
-        in_dims = args[1]
-        out_dims = args[2]
-        randomness = args[3]
-        chunk_size = args[4]
-        batch_input_args = args[5:]
-
-        if not isinstance(in_dims, (ConstantVariable, TupleVariable)):
-            unimplemented("torch.func.vmap: in_dims is not an int or tuple variable.")
-
-        if not isinstance(out_dims, (ConstantVariable, TupleVariable)):
-            unimplemented("torch.func.vmap: out_dims is not an int or tuple variable.")
-
-        if len(kwargs) > 0:
-            unimplemented(
-                "NYI - torch.func.vmap: kwargs arguments are currently unsupported."
-            )
-
-        if chunk_size.value is not None:
-            unimplemented(
-                "NYI - torch.func.vmap is not implemented when chunk_size is passed"
-            )
-
-        # Trace into tree_flatten with the list of batch_input_args.
-        flat_args, arg_spec = _make_inlined(tx, pytree.tree_flatten)(
-            ListVariable(batch_input_args)
-        ).unpack_var_sequence(tx)
-
-        # Transform in_dims into a list if it's not an integer literal.
-        in_dims_v = (
-            in_dims
-            if isinstance(in_dims.as_python_constant(), int)
-            else BuiltinVariable(list).call_function(tx, [in_dims], {})
-        )
-
-        # Trace into _broadcast_to_and_flatten with the transformed in_dims.
-        broadcasted_in_dims = _make_inlined(tx, pytree._broadcast_to_and_flatten)(
-            in_dims_v, arg_spec
-        )
-
-        # We want to pass unbatched input to speculate subgraph.
-        # So we loop through the inputs and select only one sample
-        # from the batch.
-        unbatched_input_args = []
-        for arg, in_dim in zip(
-            flat_args.unpack_var_sequence(tx),
-            broadcasted_in_dims.unpack_var_sequence(tx),
-        ):
-            if in_dim is not None:
-                assert isinstance(arg, TensorVariable)
-                unbatched_arg = arg.call_method(
-                    tx, "select", [in_dim, ConstantVariable.create(0)], {}
-                )
-                unbatched_input_args.append(unbatched_arg)
-            else:
-                unbatched_input_args.append(arg)
-
-        # Ban ops like `stride`, `storage_offset` in the traced functions.
-        # NOTE: We are conservatively banning more ops (vmap should be able
-        # to handle a few of them).
-        with tx.strict_translation_mode():
-            # trace through the function with unbatched inputs.
-            _, body_graph, body_lifted_freevars = speculate_subgraph(
-                tx,
-                fn,
-                # Returns a ListVariable, since that's where we started flattening.
-                # However, we really want to pass the inner Python list as argument.
-                _make_inlined(tx, pytree.tree_unflatten)(
-                    ListVariable(unbatched_input_args), arg_spec
-                ).unpack_var_sequence(tx),
-                {},
-                "torch.vmap",
-                source_target=self.value,
-                set_subgraph_inputs="manual",
-            )
-
-        body_name = add_subgraph(
-            tx,
-            self.source,
-            "vmap_body",
-            torch.fx.GraphModule(tx.output.nn_modules, body_graph),
-        )
-        body_node = make_attr(tx, body_name)
-
-        # body_lifted_variable should not be treated as batched.
-        # So here we update `in_dims` to reflect that.
-        # NOTE: updated_in_dims is flat list, it is ok for now
-        # as speculate_subgraph does not supports functions with non-Tensor args.
-        # (so we graph-break above)
-        updated_in_dims = TupleVariable(
-            broadcasted_in_dims.unpack_var_sequence(tx)
-            + [
-                ConstantVariable.create(None),
-            ]
-            * len(body_lifted_freevars)
-        )
-
-        vmap_proxy_args = (
-            body_node,
-            *(arg.as_proxy() for arg in (updated_in_dims, out_dims, randomness)),
-        )
-        # vmap_proxy corresponds to `vmap_proxy = vmap(fn, *vmap_args, **vmap_kwargs)`
-        vmap_proxy = tx.output.create_proxy(
-            "call_function",
-            torch.func.vmap,
-            args=tuple(vmap_proxy_args),
-            kwargs={},
-            name="vmap_proxy",
-        )
-
-        proxy_batched_fn_args = tuple(
-            arg.as_proxy() for arg in batch_input_args
-        ) + tuple(body_lifted_freevars)
-
-        # We compute the example_value by actually calling
-        # `vmap` with FakeTensors.
-        fake_batched_fn_args = itertools.chain(
-            (get_fake_value(arg.as_proxy().node, tx) for arg in batch_input_args),
-            (get_fake_value(arg.node, tx) for arg in body_lifted_freevars),
-        )
-        actual_in_dims = tuple(
-            pytree.tree_map(lambda x: x.value, updated_in_dims.items)
-        )
-
-        # NOTE: `body_graph` might have operators which
-        # will create new tensors. So it is required
-        # that we run `vmap` under FakeMode.
-        with tx.fake_mode, enable_python_dispatcher():
-            example_value = torch._functorch.vmap.vmap_impl(
-                torch.fx.GraphModule(tx.output.nn_modules, body_graph),
-                actual_in_dims,
-                out_dims.as_python_constant(),
-                randomness.value,
-                chunk_size.value,
-                *fake_batched_fn_args,
-            )
-
-        # proxy corresponds to `call = vmap_proxy(*batched_fn_args, **batched_fn_kwargs)`
-        proxy = vmap_proxy(*proxy_batched_fn_args)
-        return wrap_fx_proxy(
-            tx=tx,
-            proxy=proxy,
-            example_value=example_value,
-        )
+        return super().call_function(tx, args, kwargs)


 class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
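The removed branch above used to bail out with "NYI - torch.func.vmap: kwargs arguments are currently unsupported"; with vmap now traced by inlining vmap_impl as an ordinary user function, keyword arguments to the mapped callable can be captured. A hedged sketch, assuming capture_func_transforms is enabled:

import torch

torch._dynamo.config.capture_func_transforms = True

def scaled_sum(x, *, scale):
    return x.sum(0) * scale

@torch.compile(backend="eager")
def fn(x):
    # kwargs are forwarded unbatched by vmap; this pattern previously graph-broke
    return torch.vmap(scaled_sum)(x, scale=2.0)

print(fn(torch.randn(5, 3)))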
@@ -27,6 +27,7 @@ from ..guards import GuardBuilder
 from ..utils import (
     check_constant_args,
     check_unspec_python_args,
+    guard_if_dyn,
     has_torch_function,
     product,
     proxy_args_kwargs,
@@ -140,7 +141,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
-        from . import GradModeVariable, InferenceModeVariable, StreamVariable
+        from . import (
+            GradModeVariable,
+            InferenceModeVariable,
+            StreamVariable,
+            VmapIncrementNestingCtxManagerVariable,
+        )

         if self.value is torch.no_grad:
             if len(args) == 1 and isinstance(
@@ -193,6 +199,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
         elif self.value is torch._C.DisableTorchFunctionSubclass:
             assert not (args or kwargs)
             return TorchFunctionDisableVariable.create(tx)
+        elif self.value is torch._functorch.vmap.vmap_increment_nesting:
+            assert len(args) == 2
+            return VmapIncrementNestingCtxManagerVariable.create(
+                tx,
+                [guard_if_dyn(x) for x in args],
+            )


 class TorchInGraphFunctionVariable(BaseTorchVariable):
@@ -238,10 +250,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         ):
             tx.mark_inconsistent_side_effects()
             return ConstantVariable.create(tracing_state_functions[self.value])
-        elif self.value in (
-            torch._functorch.vmap.vmap_impl,
-            torch._functorch.eager_transforms.grad_impl,
-        ):
+        elif self.value in (torch._functorch.eager_transforms.grad_impl,):
             return TorchHigherOrderOperatorVariable.make(
                 self.value,
                 source=self.source,
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+import contextlib
 import functools
 import threading
 from torch import Tensor
@@ -274,17 +275,10 @@ def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
         return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
                              args_spec, out_dims, randomness, **kwargs)

-    from torch._dynamo import disable
-
-    # remove @disable once #114306 is fixed
-    @disable
-    def wrapper():
-        # If chunk_size is not specified.
-        return _flat_vmap(
-            func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
-        )
-
-    return wrapper()
+    # If chunk_size is not specified.
+    return _flat_vmap(
+        func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
+    )


 def get_chunk_sizes(total_elems, chunk_size):
     n_chunks = n_chunks = total_elems // chunk_size
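For reference, the chunked branch kept at the top of the hunk above is the path taken when a chunk_size is passed to torch.vmap; a quick eager usage sketch:

import torch

xs = torch.randn(6, 3)
# the batch of 6 is processed in chunks of 2 via _chunked_vmap
out = torch.vmap(torch.sum, chunk_size=2)(xs)
assert torch.allclose(out, xs.sum(dim=1))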
@@ -390,15 +384,22 @@ def _check_randomness_arg(randomness):
         raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")


+@contextlib.contextmanager
+def vmap_increment_nesting(batch_size, randomness):
+    try:
+        vmap_level = _vmap_increment_nesting(batch_size, randomness)
+        yield vmap_level
+    finally:
+        _vmap_decrement_nesting()
+
+
 @doesnt_support_saved_tensors_hooks
 def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs):
-    vmap_level = _vmap_increment_nesting(batch_size, randomness)
-    try:
+    with vmap_increment_nesting(batch_size, randomness) as vmap_level:
         batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
         batched_outputs = func(*batched_inputs, **kwargs)
         return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
-    finally:
-        _vmap_decrement_nesting()


 # `restore_vmap` is a private helper function. It is vmap but has the following
@@ -424,13 +425,10 @@ def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, r
 @doesnt_support_saved_tensors_hooks
 def restore_vmap(func, in_dims, batch_size, randomness):
     def inner(*args, **kwargs):
-        vmap_level = _vmap_increment_nesting(batch_size, randomness)
-        try:
+        with vmap_increment_nesting(batch_size, randomness) as vmap_level:
             batched_inputs = wrap_batched(args, in_dims, vmap_level)
             batched_outputs = func(*batched_inputs, **kwargs)
             return unwrap_batched(batched_outputs, vmap_level)
-        finally:
-            _vmap_decrement_nesting()
     return inner

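Both rewritten call sites above rely on the new vmap_increment_nesting helper. A hedged sketch of using it in isolation, assuming it is importable from torch._functorch.vmap as the hunks indicate:

from torch._functorch.vmap import vmap_increment_nesting

# pairs _vmap_increment_nesting with a guaranteed _vmap_decrement_nesting,
# replacing the try/finally blocks removed above
with vmap_increment_nesting(4, "error") as vmap_level:
    # one extra functorch dynamic layer is active inside this block
    assert isinstance(vmap_level, int)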
@@ -370,6 +370,15 @@ static int64_t currentLevel()
   return current_level;
 }

+static c10::optional<int64_t> maybe_current_level() {
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  if (maybe_layer.has_value()) {
+    int current_level = maybe_layer->layerId();
+    return current_level;
+  }
+  return nullopt;
+}
+
 static void tls_set_vmap_excluded(bool excluded) {
   c10::impl::tls_set_dispatch_key_excluded(
       c10::DispatchKey::FuncTorchBatched, excluded);
@@ -474,6 +483,7 @@ void initFuncTorchBindings(PyObject* module) {
   m.def("get_unwrapped", &get_unwrapped);
   m.def("maybe_get_level", &maybe_get_level);
   m.def("maybe_get_bdim", &maybe_get_bdim);
+  m.def("maybe_current_level", &maybe_current_level);
   m.def("current_level", &currentLevel);
   m.def("tls_set_vmap_excluded", &tls_set_vmap_excluded);
   m.def("_set_dynamic_layer_keys_included", &_set_dynamic_layer_keys_included);
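A hedged usage sketch of the binding registered above, assuming it is exposed as torch._C._functorch.maybe_current_level: it returns None when no functorch layer is active and the innermost dynamic-layer id otherwise.

import torch
from torch._C._functorch import maybe_current_level

assert maybe_current_level() is None  # no functorch layer active here

def f(x):
    # inside vmap a dynamic layer is on the stack, so a level is reported
    assert maybe_current_level() is not None
    return x * 2

torch.vmap(f)(torch.randn(3))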
@@ -2335,6 +2335,50 @@ dynamo_expected_failures = {
     "TestControlFlowTraced.test_map_functionalized", # functorch/test_control_flow
     "TestControlFlowTraced.test_nested_map_cond_symbolic", # functorch/test_control_flow
     "TestControlFlowTraced.test_nested_map_cond_real", # functorch/test_control_flow
+    "TestJacCPU.test_against_reference_correctness_different_devices_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_against_reference_default_arg_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_against_reference_multi_input_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_against_reference_multi_input_multi_output_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_against_reference_simple_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_against_reference_unrelated_outputs_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_against_reference_zero_dim_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_argnums_defaults_to_zero_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_aux_pytree_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_chunk_jacrev_composition__preallocate_and_copy_False_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_chunk_jacrev_composition__preallocate_and_copy_True_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_dimensionality_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_empty_output_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_hessian_simple_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_inplace_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_jac_with_non_tensor_args_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_multiple_inputs_outputs_pytree_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_multiple_inputs_pytree_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_multiple_outputs_multiple_argnums_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_multiple_outputs_single_argnums_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_outputs_can_any_pytree_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_unrelated_input_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestJacCPU.test_unrelated_output_jacfwd_cpu", # functorch/test_eager_transforms.py
+    "TestVmapJvpInplaceViewCPU.test_all_dual_base_inplace_cpu", # functorch/test_eager_transforms.py
+    "TestVmapJvpInplaceViewCPU.test_all_dual_base_view_inplace_cpu", # functorch/test_eager_transforms.py
+    "TestVmapJvpInplaceViewCPU.test_all_dual_no_view_cpu", # functorch/test_eager_transforms.py
+    "TestVmapJvpInplaceViewCPU.test_right_dual_base_prop_cpu", # functorch/test_eager_transforms.py
+    "TestVmapJvpInplaceViewCPU.test_right_dual_view_prop_cpu", # functorch/test_eager_transforms.py
+    "TestHessianCPU.test_hessian_vectorize_correctness_multi_input_cpu", # functorch/test_eager_transforms.py
+    "TestHessianCPU.test_hessian_vectorize_correctness_simple_cpu", # functorch/test_eager_transforms.py
+    "TestHessianCPU.test_hessian_vectorize_correctness_unrelated_outputs_cpu", # functorch/test_eager_transforms.py
+    "TestHessianCPU.test_jacfwd_different_levels_cpu", # functorch/test_eager_transforms.py
+    "TestExamplesCorrectnessCPU.test_ensemble_regression_mechanism_functional_call_cpu", # functorch/test_eager_transforms.py
+    "TestExamplesCorrectnessCPU.test_ensemble_regression_mechanism_make_functional_cpu", # functorch/test_eager_transforms.py
+    "TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_AlphaDropout_mechanism_functional_call_cpu", # functorch/test_eager_transforms.py
+    "TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_AlphaDropout_mechanism_make_functional_cpu", # functorch/test_eager_transforms.py
+    "TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_Dropout_mechanism_functional_call_cpu", # functorch/test_eager_transforms.py
+    "TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_Dropout_mechanism_make_functional_cpu", # functorch/test_eager_transforms.py
+    "TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_FeatureAlphaDropout_mechanism_functional_call_cpu", # functorch/test_eager_transforms.py
+    "TestExamplesCorrectnessCPU.test_find_learning_rate_ensembling_FeatureAlphaDropout_mechanism_make_functional_cpu", # functorch/test_eager_transforms.py
+    "TestHigherOrderOperatorInteractionCPU.test_vmap_grad_sum_cpu", # functorch/test_eager_transforms.py
+    "TestFunctionalizeCPU.test_multioutput_view_cpu", # functorch/test_eager_transforms.py
+    "TestFunctionalizeCPU.test_simple_view_cpu", # functorch/test_eager_transforms.py
+    "TestFunctionalizeCPU.test_vmap_functionalize_jvp_cpu", # functorch/test_eager_transforms.py
     "TestMetaKernel.test_addmm_invalid_dtype", # lazy/test_meta_kernel
     "TestVerifyCorrectness.test_incorrect_verify_true", # dynamo/test_verify_correctness
     "TestVerifyCorrectness.test_torchscript", # dynamo/test_verify_correctness
@@ -2579,7 +2623,11 @@ dynamo_expected_failures = {
     "FuncTorchHigherOrderOpTests.test_vmap_free_const", # dynamo/test_higher_order_ops
     "FuncTorchHigherOrderOpTests.test_vmap_multiple_invocation_in_dims", # dynamo/test_higher_order_ops
     "FuncTorchHigherOrderOpTests.test_grad", # dynamo/test_higher_order_ops
-    "FuncTorchHigherOrderOpTests.test_vmap_illegal_op_graph_break", # dynamo/test_higher_order_ops
+    "FuncTorchHigherOrderOpTests.test_vmap_with_conditional_graph_break", # dynamo/test_higher_order_ops
+    "FuncTorchHigherOrderOpTests.test_vmap_with_graph_break", # dynamo/test_higher_order_ops
+    "FuncTorchHigherOrderOpTests.test_vmap_with_graph_break_2", # dynamo/test_higher_order_ops
+    "FuncTorchHigherOrderOpTests.test_vmap_with_graph_break_lambda", # dynamo/test_higher_order_ops
+    "FuncTorchHigherOrderOpTests.test_vmap_previous_illegal_op_no_graph_break", # dynamo/test_higher_order_ops
     "HigherOrderOpTests.test_cond_pytree_operands", # dynamo/test_higher_order_ops
     "HigherOrderOpTests.test_cond_branches_no_arguments_no_closure", # dynamo/test_higher_order_ops
     "FuncTorchHigherOrderOpTests.test_vmap_side_effects", # dynamo/test_higher_order_ops
@@ -2775,6 +2823,8 @@ dynamo_expected_failures = {
     "TestVmapOperatorsLegacy.test_unbind", # test_legacy_vmap
     "TestVmapAPILegacy.test_non_default_in_dims_out_dims", # test_legacy_vmap
     "TestVmapOperatorsLegacy.test_T_numpy", # test_legacy_vmap
+    "TestNamedTensor.test_expand", # test_namedtensor
+    "TestNamedTensor.test_masked_fill", # test_namedtensor
     "TestNamedTensor.test_addmv", # test_namedtensor
     "TestNamedTensor.test_cummax_cummin", # test_namedtensor
     "TestNamedTensor.test_no_jit_script_support", # test_namedtensor
@@ -928,7 +928,7 @@ def tree_map_(
     """
     leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
     flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
-    deque(map(func, *flat_args), maxlen=0)  # consume and exhaust the iterable
+    tuple(map(func, *flat_args))  # consume and exhaust the iterable
     return tree
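The hunk swaps deque(..., maxlen=0) for tuple(...); either form fully consumes the map iterator so func still runs on every leaf. A quick usage sketch of the helper touched here, assuming it is torch.utils._pytree.tree_map_:

import torch
from torch.utils._pytree import tree_map_

tree = {"a": torch.zeros(2), "b": [torch.zeros(3)]}
tree_map_(lambda t: t.add_(1), tree)  # applied for its side effect only
print(tree["a"])  # tensor([1., 1.])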