Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add support for tracing vmap in pre-dispatch export (#154650)
Summary: The ONNX team and a recent transformer upgrade ran into this error, and we also hit it during our export benchmarking. This diff makes it possible to trace through the vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR; in the future we should desugar them to post-grad ops. The implementation strategy is:

1. Add Python wrappers around the vmap APIs so that we can attach a custom torch_function handler that is active only during non-strict export. We don't want to add this to the default torch_function handler because that would break BC.
2. Make some dynamo changes so that it picks up the new Python wrapper APIs. When we do strict export, we need to re-materialize these APIs in pre-dispatch IR from torch IR. We could avoid this by special-casing export in dynamo to proxy different API calls, but that feels too chaotic because you would need to be able to proxy two different variants of the same vmap API.

Test Plan: CI

Differential Revision: D75623875

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154650
Approved by: https://github.com/ezyang, https://github.com/zou3519
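A minimal sketch of the wrapper idea in point 1, under stated assumptions: the _EXPORT_HANDLER hook and its installation are illustrative placeholders, not the actual torch/_functorch/predispatch.py implementation; only the underlying torch._C._functorch._add_batch_dim binding is taken from the diff below.

from typing import Callable, Optional

import torch

# Hypothetical hook that non-strict export would install while tracing and
# remove afterwards; it is expected to record the call as a node in the
# pre-dispatch graph instead of executing it eagerly.
_EXPORT_HANDLER: Optional[Callable] = None


def _add_batch_dim(tensor: torch.Tensor, batch_dim: int, level: int) -> torch.Tensor:
    # Python-level wrapper around the C++ functorch binding. Export traces
    # calls to this wrapper, so the op can show up as a
    # torch._functorch.predispatch.* node in pre-dispatch IR without changing
    # the default torch_function handling that eager mode relies on.
    if _EXPORT_HANDLER is not None:  # only true while non-strict export traces
        return _EXPORT_HANDLER(_add_batch_dim, tensor, batch_dim, level)
    return torch._C._functorch._add_batch_dim(tensor, batch_dim, level)

Eager callers see no behavior change; only an active export trace diverts the call, which is what keeps the BC concern above out of the default torch_function path.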
Committed by: PyTorch MergeBot
Parent: c5cb255625
Commit: dbef606631
.flake8 (1 addition)
@@ -48,6 +48,7 @@ per-file-ignores =
     torch/__init__.py: F401,TOR901
     torch/_custom_op/impl.py: TOR901
     torch/_export/serde/upgrade.py: TOR901
+    torch/_functorch/predispatch.py: TOR901
     torch/_functorch/vmap.py: TOR901
     torch/_inductor/test_operators.py: TOR901
     torch/_library/abstract_impl.py: TOR901
@@ -292,6 +292,56 @@ class AOTAutogradCacheTests(InductorTestCase):
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_vmap(self):
        """
        make
        """

        def fn(x, y):
            f = lambda x, y: (x * y + 1).sum(dim=0)  # noqa: E731
            vmapped = torch.vmap(f)(x, y)
            return vmapped.sum(dim=0)

        x = torch.randn(25, requires_grad=True)
        y = torch.randn(25, requires_grad=True)
        x2 = x.detach().clone().requires_grad_(True)
        y2 = y.detach().clone().requires_grad_(True)

        compiled_fn = torch.compile(fn, backend="inductor")

        # A first call should miss in the cache.
        self.assertEqual(fn(x, y), compiled_fn(x2, y2))
        fn(x, y).sum().backward()
        compiled_fn(x2, y2).sum().backward()
        self.assertEqual(x.grad, x2.grad)
        self.assertEqual(y.grad, y2.grad)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Reset all tensors
        x = torch.randn(25, requires_grad=True)
        y = torch.randn(25, requires_grad=True)
        x2 = x.detach().clone().requires_grad_(True)
        y2 = y.detach().clone().requires_grad_(True)

        # A second call should hit. (First reset so in-memory guards
        # don't prevent compilation).
        self._clear_dynamo_and_codecache()
        self.assertEqual(fn(x, y), compiled_fn(x2, y2))
        fn(x, y).sum().backward()
        compiled_fn(x2, y2).sum().backward()
        self.assertEqual(x.grad, x2.grad)
        self.assertEqual(y.grad, y2.grad)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
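The expected-graph hunks that follow mostly swap torch._C._functorch.* and torch._functorch.vmap.lazy_load_decompositions calls for the new torch._functorch.predispatch.* wrappers. As a hand-written illustration (independent of those wrappers and not produced by export), this is roughly what one vmap level means in terms of the low-level bindings those graphs call:

import torch
from torch._C import _functorch

def manual_vmap_sum(x: torch.Tensor) -> torch.Tensor:
    # Behaves like torch.vmap(lambda t: t.sum(0))(x) for a 2-D input; real
    # vmap adds error checking and pytree handling on top of this.
    batch_size = x.size(0)
    level = _functorch._vmap_increment_nesting(batch_size, "error")
    try:
        # Hide the batch dimension behind a batched tensor at this level ...
        bt = _functorch._add_batch_dim(x, 0, level)
        # ... run the per-sample computation ...
        out = bt.sum(0)
        # ... then materialize the batch dimension again at dim 0.
        return _functorch._remove_batch_dim(out, level, batch_size, 0)
    finally:
        _functorch._vmap_decrement_nesting()

x = torch.randn(4, 3)
assert torch.allclose(manual_vmap_sum(x), torch.vmap(lambda t: t.sum(0))(x))

Each nesting level in the traced graphs follows the same increment / add_batch_dim / body / remove_batch_dim / decrement pattern, with the level and batch size wired in as graph inputs.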
@ -3084,29 +3084,29 @@ def forward(self, L_a_ : torch.SymInt, L_b_ : torch.SymInt, L_c_ : torch.SymInt,
|
||||
b = torch.arange(l_b_)
|
||||
c = torch.arange(l_c_)
|
||||
d = torch.arange(l_d_)
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(l_d_, 'error'); _vmap_increment_nesting = None
|
||||
child = torch._C._functorch._add_batch_dim(d, 0, 1); d = None
|
||||
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(l_c_, 'error'); _vmap_increment_nesting_1 = None
|
||||
child_1 = torch._C._functorch._add_batch_dim(c, 0, 2); c = None
|
||||
lazy_load_decompositions_2 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_2 = None
|
||||
_vmap_increment_nesting_2 = torch._C._functorch._vmap_increment_nesting(l_b_, 'error'); _vmap_increment_nesting_2 = None
|
||||
child_2 = torch._C._functorch._add_batch_dim(b, 0, 3); b = None
|
||||
lazy_load_decompositions_3 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_3 = None
|
||||
_vmap_increment_nesting_3 = torch._C._functorch._vmap_increment_nesting(l_a_, 'error'); _vmap_increment_nesting_3 = None
|
||||
_add_batch_dim_3 = torch._C._functorch._add_batch_dim(a, 0, 4); a = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(l_d_, 'error'); _vmap_increment_nesting = None
|
||||
child = torch._functorch.predispatch._add_batch_dim(d, 0, 1); d = None
|
||||
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(l_c_, 'error'); _vmap_increment_nesting_1 = None
|
||||
child_1 = torch._functorch.predispatch._add_batch_dim(c, 0, 2); c = None
|
||||
lazy_load_decompositions_2 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_2 = None
|
||||
_vmap_increment_nesting_2 = torch._functorch.predispatch._vmap_increment_nesting(l_b_, 'error'); _vmap_increment_nesting_2 = None
|
||||
child_2 = torch._functorch.predispatch._add_batch_dim(b, 0, 3); b = None
|
||||
lazy_load_decompositions_3 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_3 = None
|
||||
_vmap_increment_nesting_3 = torch._functorch.predispatch._vmap_increment_nesting(l_a_, 'error'); _vmap_increment_nesting_3 = None
|
||||
_add_batch_dim_3 = torch._functorch.predispatch._add_batch_dim(a, 0, 4); a = None
|
||||
add = _add_batch_dim_3 + child_2; _add_batch_dim_3 = child_2 = None
|
||||
add_1 = add + child_1; add = child_1 = None
|
||||
batched_outputs = add_1 + child; add_1 = child = None
|
||||
batched_outputs_1 = torch._C._functorch._remove_batch_dim(batched_outputs, 4, l_a_, 0); batched_outputs = l_a_ = None
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
batched_outputs_2 = torch._C._functorch._remove_batch_dim(batched_outputs_1, 3, l_b_, 0); batched_outputs_1 = l_b_ = None
|
||||
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
batched_outputs_3 = torch._C._functorch._remove_batch_dim(batched_outputs_2, 2, l_c_, 0); batched_outputs_2 = l_c_ = None
|
||||
_vmap_decrement_nesting_2 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_2 = None
|
||||
_remove_batch_dim_3 = torch._C._functorch._remove_batch_dim(batched_outputs_3, 1, l_d_, 0); batched_outputs_3 = l_d_ = None
|
||||
_vmap_decrement_nesting_3 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_3 = None
|
||||
batched_outputs_1 = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 4, l_a_, 0); batched_outputs = l_a_ = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
batched_outputs_2 = torch._functorch.predispatch._remove_batch_dim(batched_outputs_1, 3, l_b_, 0); batched_outputs_1 = l_b_ = None
|
||||
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
batched_outputs_3 = torch._functorch.predispatch._remove_batch_dim(batched_outputs_2, 2, l_c_, 0); batched_outputs_2 = l_c_ = None
|
||||
_vmap_decrement_nesting_2 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_2 = None
|
||||
_remove_batch_dim_3 = torch._functorch.predispatch._remove_batch_dim(batched_outputs_3, 1, l_d_, 0); batched_outputs_3 = l_d_ = None
|
||||
_vmap_decrement_nesting_3 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_3 = None
|
||||
return (_remove_batch_dim_3,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -3739,11 +3739,11 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
||||
child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
||||
|
||||
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
||||
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
||||
@ -3786,18 +3786,18 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None
|
||||
|
||||
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
|
||||
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
|
||||
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
|
||||
|
||||
_add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None
|
||||
_add_batch_dim_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 3); basis = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim_1 = None
|
||||
batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None
|
||||
chunked_result: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
||||
split_1: "f32[12, 4, 3]" = split[0]; split = None
|
||||
@ -3816,9 +3816,9 @@ class GraphModule(torch.nn.Module):
|
||||
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
||||
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
||||
|
||||
results_1: "f32[12, 4, 3, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
||||
results_1: "f32[12, 4, 3, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
||||
|
||||
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
|
||||
movedim: "f32[4, 3, 4, 3, 12]" = results_1.movedim(0, -1); results_1 = None
|
||||
split_2 = movedim.split((12,), dim = -1); movedim = None
|
||||
@ -3867,11 +3867,11 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
||||
child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
||||
|
||||
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
||||
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
||||
@ -3916,18 +3916,18 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None
|
||||
|
||||
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
|
||||
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
|
||||
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
|
||||
|
||||
_add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None
|
||||
_add_batch_dim_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 3); basis = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); primals_out = child_4 = _add_batch_dim_1 = None
|
||||
child_5: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
child_6: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None
|
||||
child_6: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
split = child_6.split((12,), dim = 0); child_6 = None
|
||||
split_1: "f32[12, 3, 4]" = split[0]; split = None
|
||||
@ -3947,9 +3947,9 @@ class GraphModule(torch.nn.Module):
|
||||
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
||||
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
||||
|
||||
child_10: "f32[12, 4, 3, 3, 4]" = torch._C._functorch._remove_batch_dim(child_9, 1, 12, 0); child_9 = None
|
||||
child_10: "f32[12, 4, 3, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_9, 1, 12, 0); child_9 = None
|
||||
|
||||
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
|
||||
movedim: "f32[4, 3, 3, 4, 12]" = child_10.movedim(0, -1); child_10 = None
|
||||
split_2 = movedim.split((12,), dim = -1); movedim = None
|
||||
@ -4014,18 +4014,18 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
basis: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
|
||||
_add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 1); basis = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim = None
|
||||
batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
||||
chunked_result: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
||||
split_1: "f32[12, 4, 3]" = split[0]; split = None
|
||||
@ -4092,18 +4092,18 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
|
||||
_add_batch_dim: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 1); basis = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim = None
|
||||
batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
||||
chunked_result: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
||||
split_1: "f32[12, 3, 4]" = split[0]; split = None
|
||||
@ -4172,18 +4172,18 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
|
||||
_add_batch_dim: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 1); basis = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim = None
|
||||
batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
||||
chunked_result: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
split = chunked_result.split((12,), dim = 0); chunked_result = None
|
||||
split_1: "f32[12, 3, 4]" = split[0]; split = None
|
||||
@ -5229,11 +5229,11 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
||||
child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
||||
|
||||
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
||||
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
||||
@ -5259,9 +5259,9 @@ class GraphModule(torch.nn.Module):
|
||||
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
||||
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
||||
|
||||
results: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
||||
results: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
movedim: "f32[4, 3, 12]" = results.movedim(0, -1); results = None
|
||||
split = movedim.split((12,), dim = -1); movedim = None
|
||||
@ -5310,11 +5310,11 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
||||
child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
||||
|
||||
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
||||
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
||||
@ -5341,9 +5341,9 @@ class GraphModule(torch.nn.Module):
|
||||
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
||||
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
||||
|
||||
results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
||||
results: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
movedim: "f32[3, 4, 12]" = results.movedim(0, -1); results = None
|
||||
split = movedim.split((12,), dim = -1); movedim = None
|
||||
@ -5392,11 +5392,11 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
||||
child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
||||
|
||||
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
||||
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
||||
@ -5425,10 +5425,10 @@ class GraphModule(torch.nn.Module):
|
||||
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
||||
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
||||
|
||||
results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
||||
aux_2: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(aux_1, 1, 12, 0); aux_1 = None
|
||||
results: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None
|
||||
aux_2: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(aux_1, 1, 12, 0); aux_1 = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
aux_3: "f32[4, 3]" = aux_2[0]; aux_2 = None
|
||||
|
||||
@ -5479,11 +5479,11 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'same'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'same'); _vmap_increment_nesting = None
|
||||
|
||||
child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
|
||||
child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
|
||||
|
||||
_jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None
|
||||
_set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None
|
||||
@ -5517,10 +5517,10 @@ class GraphModule(torch.nn.Module):
|
||||
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
|
||||
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
|
||||
|
||||
child_8: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_6, 1, 12, 0); child_6 = None
|
||||
child_9: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(child_7, 1, 12, 0); child_7 = None
|
||||
child_8: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_6, 1, 12, 0); child_6 = None
|
||||
child_9: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(child_7, 1, 12, 0); child_7 = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
movedim: "f32[3, 4, 12]" = child_8.movedim(0, -1); child_8 = None
|
||||
split = movedim.split((12,), dim = -1); movedim = None
|
||||
@ -6260,19 +6260,19 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[3, 3, 3]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
|
||||
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
||||
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
||||
batched_outputs: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
||||
|
||||
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
_remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
return (_remove_batch_dim,)
|
||||
""",
|
||||
)
|
||||
@ -6298,20 +6298,20 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[3, 3, 3]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
|
||||
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
||||
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
||||
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
||||
batched_outputs: "f32[3]" = add + 3; add = None
|
||||
|
||||
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
_remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
return (_remove_batch_dim,)
|
||||
""",
|
||||
)
|
||||
@ -6338,20 +6338,20 @@ class GraphModule(torch.nn.Module):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
|
||||
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
||||
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
||||
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
||||
batched_outputs: "f32[3, 3]" = add + l_y_; add = l_y_ = None
|
||||
|
||||
_remove_batch_dim: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
_remove_batch_dim: "f32[3, 3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
return (_remove_batch_dim,)
|
||||
""",
|
||||
)
|
||||
@ -6379,21 +6379,21 @@ class GraphModule(torch.nn.Module):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None
|
||||
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim_1: "f32[3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 1, 1); l_y_ = None
|
||||
|
||||
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
||||
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
||||
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
||||
batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
|
||||
|
||||
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
_remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
return (_remove_batch_dim,)
|
||||
""",
|
||||
)
|
||||
@ -6423,21 +6423,21 @@ class GraphModule(torch.nn.Module):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None
|
||||
_add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim_1: "f32[3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 1, 1); l_y_ = None
|
||||
|
||||
sum_1: "f32[3]" = _add_batch_dim.sum(0)
|
||||
sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
||||
add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
||||
batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
|
||||
|
||||
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
_remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
return (_remove_batch_dim,)
|
||||
""",
|
||||
)
|
||||
@ -6463,29 +6463,29 @@ class GraphModule(torch.nn.Module):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
child: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
child_1: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None
|
||||
child: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
child_1: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 0, 1); l_y_ = None
|
||||
|
||||
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
|
||||
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None
|
||||
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None
|
||||
|
||||
_add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(child, 1, 2); child = None
|
||||
_add_batch_dim_3: "f32[3]" = torch._C._functorch._add_batch_dim(child_1, 1, 2); child_1 = None
|
||||
_add_batch_dim_2: "f32[3]" = torch._functorch.predispatch._add_batch_dim(child, 1, 2); child = None
|
||||
_add_batch_dim_3: "f32[3]" = torch._functorch.predispatch._add_batch_dim(child_1, 1, 2); child_1 = None
|
||||
|
||||
batched_outputs: "f32[3]" = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None
|
||||
|
||||
batched_outputs_1: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None
|
||||
batched_outputs_1: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
_remove_batch_dim_1: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 3, 0); batched_outputs_1 = None
|
||||
_remove_batch_dim_1: "f32[3, 3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs_1, 1, 3, 0); batched_outputs_1 = None
|
||||
|
||||
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
return (_remove_batch_dim_1,)
|
||||
""",
|
||||
)
|
||||
@ -6512,27 +6512,27 @@ class GraphModule(torch.nn.Module):
|
||||
l_y_ = L_y_
|
||||
l_x_ = L_x_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(5, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
child: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None
|
||||
child: "f32[3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 0, 1); l_y_ = None
|
||||
|
||||
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None
|
||||
|
||||
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None
|
||||
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None
|
||||
|
||||
_add_batch_dim_1: "f32[]" = torch._C._functorch._add_batch_dim(child, 0, 2); child = None
|
||||
_add_batch_dim_1: "f32[]" = torch._functorch.predispatch._add_batch_dim(child, 0, 2); child = None
|
||||
|
||||
batched_outputs: "f32[2, 3]" = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None
|
||||
|
||||
batched_outputs_1: "f32[3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None
|
||||
batched_outputs_1: "f32[3, 2, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
|
||||
_remove_batch_dim_1: "f32[5, 3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 5, 0); batched_outputs_1 = None
|
||||
_remove_batch_dim_1: "f32[5, 3, 2, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs_1, 1, 5, 0); batched_outputs_1 = None
|
||||
|
||||
_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
|
||||
return (_remove_batch_dim_1,)
|
||||
""",
|
||||
)
|
||||
@ -6557,19 +6557,19 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[2, 4, 3]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
|
||||
child: "f32[3]" = _add_batch_dim.sum(0)
|
||||
child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
||||
|
||||
_remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 0); child = None
|
||||
_remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None
|
||||
_remove_batch_dim: "f32[2, 3]" = torch._functorch.predispatch._remove_batch_dim(child, 1, 2, 0); child = None
|
||||
_remove_batch_dim_1: "f32[2, 4]" = torch._functorch.predispatch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
return (_remove_batch_dim, _remove_batch_dim_1)
|
||||
""",
|
||||
)
|
||||
@ -6594,19 +6594,19 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[2, 4, 3]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
|
||||
child: "f32[3]" = _add_batch_dim.sum(0)
|
||||
child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
||||
|
||||
_remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None
|
||||
_remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None
|
||||
_remove_batch_dim: "f32[3, 2]" = torch._functorch.predispatch._remove_batch_dim(child, 1, 2, 1); child = None
|
||||
_remove_batch_dim_1: "f32[2, 4]" = torch._functorch.predispatch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
return (_remove_batch_dim, _remove_batch_dim_1)
|
||||
""",
|
||||
)
|
||||
@ -6632,19 +6632,19 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[2, 4, 3]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None
|
||||
|
||||
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
|
||||
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
|
||||
|
||||
_add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
_add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None
|
||||
|
||||
child: "f32[3]" = _add_batch_dim.sum(0)
|
||||
child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None
|
||||
|
||||
_remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None
|
||||
_remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None
|
||||
_remove_batch_dim: "f32[3, 2]" = torch._functorch.predispatch._remove_batch_dim(child, 1, 2, 1); child = None
|
||||
_remove_batch_dim_1: "f32[2, 4]" = torch._functorch.predispatch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None
|
||||
|
||||
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
|
||||
return (_remove_batch_dim, _remove_batch_dim_1)
|
||||
""",
|
||||
)
|
||||
|
@@ -2547,6 +2547,67 @@ graph():
        res = ep.module()(ref_x)
        self.assertEqual(res, ref_out)

    @testing.expectedFailureSerDer  # can't serialize functorch ops
    @testing.expectedFailureSerDerNonStrict  # can't serialize functorch ops
    @testing.expectedFailureCppRuntime
    def test_vmap(self):
        class Vmap(torch.nn.Module):
            def forward(self, x, y):
                f = lambda x, y: (x * y + 1).sum(dim=0)  # noqa: E731
                vmapped = torch.vmap(f)(x, y)
                return vmapped.sum(dim=0)

        DYN = torch.export.Dim.DYNAMIC
        inputs = (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))
        dynamic = {"x": {0: DYN}, "y": {0: DYN}}
        ep = torch.export.export(Vmap(), inputs, {}, dynamic_shapes=dynamic)
        self.assertExpectedInline(
            str(ep.graph).strip(),
            """\
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=2] = placeholder[target=y]
    %sym_size_int_3 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %lazy_load_decompositions : [num_users=0] = call_function[target=torch._functorch.predispatch.lazy_load_decompositions](args = (), kwargs = {})
    %_vmap_increment_nesting : [num_users=0] = call_function[target=torch._functorch.predispatch._vmap_increment_nesting](args = (%sym_size_int_3, error), kwargs = {})
    %_add_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%x, 0, 1), kwargs = {})
    %_add_batch_dim_1 : [num_users=1] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%y, 0, 1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_add_batch_dim, %_add_batch_dim_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%add, [0]), kwargs = {})
    %_remove_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%sum_1, 1, %sym_size_int_3, 0), kwargs = {})
    %_vmap_decrement_nesting : [num_users=0] = call_function[target=torch._functorch.predispatch._vmap_decrement_nesting](args = (), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%_remove_batch_dim, [0]), kwargs = {})
    return (sum_2,)""",
        )
        ep = torch.export.export(
            Vmap(), inputs, {}, dynamic_shapes=dynamic, strict=True
        )
        self.assertExpectedInline(
            str(ep.graph).strip(),
            """\
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=2] = placeholder[target=y]
    %sym_size_int_2 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %lazy_load_decompositions : [num_users=0] = call_function[target=torch._functorch.predispatch.lazy_load_decompositions](args = (), kwargs = {})
    %_vmap_increment_nesting : [num_users=0] = call_function[target=torch._functorch.predispatch._vmap_increment_nesting](args = (%sym_size_int_2, error), kwargs = {})
    %_add_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%x, 0, 1), kwargs = {})
    %_add_batch_dim_1 : [num_users=1] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%y, 0, 1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_add_batch_dim, %_add_batch_dim_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, 1), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%add, [0]), kwargs = {})
    %_remove_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%sum_1, 1, %sym_size_int_2, 0), kwargs = {})
    %_vmap_decrement_nesting : [num_users=0] = call_function[target=torch._functorch.predispatch._vmap_decrement_nesting](args = (), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%_remove_batch_dim, [0]), kwargs = {})
    return (sum_2,)""",
        )
        self.assertTrue(torch.allclose(ep.module()(*inputs), Vmap()(*inputs)))
        ep = export(Vmap(), inputs, {}, dynamic_shapes=dynamic).run_decompositions({})
        self.assertTrue(torch.allclose(ep.module()(*inputs), Vmap()(*inputs)))

    @testing.expectedFailureLegacyExportNonStrict  # Old export doesn't work with subclasses
    @testing.expectedFailureLegacyExportStrict  # Old export doesn't work with subclasses
    def test_subclass_nested_attr_access(self):
        class Foo(torch.nn.Module):
            def __init__(self):
@ -4061,53 +4061,53 @@ class GraphModule(torch.nn.Module):
child_4: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 1, None, 2)
child_5: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 1, None, 2)

lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None

_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting = None
_vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting = None

_add_batch_dim: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
_add_batch_dim_1: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_1, 0, 1); child_1 = None
_add_batch_dim_2: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_2, 0, 1); child_2 = _add_batch_dim_2 = None
_add_batch_dim_3: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_3, 0, 1); child_3 = _add_batch_dim_3 = None
_add_batch_dim_4: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_4, 0, 1); child_4 = _add_batch_dim_4 = None
_add_batch_dim_5: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_5, 0, 1); child_5 = None
_add_batch_dim: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None
_add_batch_dim_1: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_1, 0, 1); child_1 = None
_add_batch_dim_2: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_2, 0, 1); child_2 = _add_batch_dim_2 = None
_add_batch_dim_3: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_3, 0, 1); child_3 = _add_batch_dim_3 = None
_add_batch_dim_4: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_4, 0, 1); child_4 = _add_batch_dim_4 = None
_add_batch_dim_5: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_5, 0, 1); child_5 = None

a: "f32[10, 2]" = _add_batch_dim + _add_batch_dim_5; _add_batch_dim = None
b: "f32[10, 2]" = _add_batch_dim_1 - _add_batch_dim_5; _add_batch_dim_1 = _add_batch_dim_5 = None

child_6: "f32[10, 2]" = a - b

child_7: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(a, 1, 1, 0); a = None
child_8: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(b, 1, 1, 0); b = None
child_9: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(child_6, 1, 1, 0); child_6 = None
child_7: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(a, 1, 1, 0); a = None
child_8: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(b, 1, 1, 0); b = None
child_9: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(child_6, 1, 1, 0); child_6 = None

_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
_vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None

child_10: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 2, None, 2)
child_11: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 2, None, 2)
child_12: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 2, None, 2)

lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None

_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_1 = None
_vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_1 = None

_add_batch_dim_6: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_7, 0, 1)
_add_batch_dim_7: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_8, 0, 1)
_add_batch_dim_8: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_9, 0, 1); _add_batch_dim_8 = None
_add_batch_dim_9: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_10, 0, 1); child_10 = _add_batch_dim_9 = None
_add_batch_dim_10: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_11, 0, 1); child_11 = _add_batch_dim_10 = None
_add_batch_dim_11: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_12, 0, 1); child_12 = None
_add_batch_dim_6: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_7, 0, 1)
_add_batch_dim_7: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_8, 0, 1)
_add_batch_dim_8: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_9, 0, 1); _add_batch_dim_8 = None
_add_batch_dim_9: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_10, 0, 1); child_10 = _add_batch_dim_9 = None
_add_batch_dim_10: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_11, 0, 1); child_11 = _add_batch_dim_10 = None
_add_batch_dim_11: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_12, 0, 1); child_12 = None

a_1: "f32[10, 2]" = _add_batch_dim_6 + _add_batch_dim_11; _add_batch_dim_6 = None
b_1: "f32[10, 2]" = _add_batch_dim_7 - _add_batch_dim_11; _add_batch_dim_7 = _add_batch_dim_11 = None

child_13: "f32[10, 2]" = a_1 - b_1

child_14: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(a_1, 1, 1, 0); a_1 = None
child_15: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(b_1, 1, 1, 0); b_1 = None
child_16: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(child_13, 1, 1, 0); child_13 = None
child_14: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(a_1, 1, 1, 0); a_1 = None
child_15: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(b_1, 1, 1, 0); b_1 = None
child_16: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(child_13, 1, 1, 0); child_13 = None

_vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
_vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None

slice_10: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 0, 1); elem_3 = None
cat: "f32[2, 10, 2]" = torch.cat([slice_10, child_14], dim = 0); slice_10 = child_14 = None
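
The expected graphs above can be reproduced end to end. A minimal sketch (not part of this diff; the module, input shapes, and names are illustrative) of exporting a vmapped function in non-strict mode and inspecting the resulting pre-dispatch graph:

import torch
from torch.export import export

class Vmapped(torch.nn.Module):
    def forward(self, x, y):
        # vmap over the leading dimension, as in the tests above
        return torch.vmap(lambda a, b: (a * b + 1).sum(dim=0))(x, y)

inputs = (torch.randn(4, 3), torch.randn(4, 3))
ep = export(Vmapped(), inputs, strict=False)
# The printed graph is expected to contain call_function nodes targeting
# torch._functorch.predispatch._add_batch_dim / _remove_batch_dim / _vmap_*_nesting.
print(ep.graph_module.code)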
@ -243,6 +243,8 @@ manual_torch_name_rule_map: dict[
"torch._C.set_autocast_xla_dtype": SkipFunctionVariable,
"torch._C.set_autocast_xla_enabled": SkipFunctionVariable,
"torch.resize_as_": SkipFunctionVariable,
"torch._functorch.predispatch._add_batch_dim": TorchInGraphFunctionVariable,
"torch._functorch.predispatch._remove_batch_dim": TorchInGraphFunctionVariable,
"torch.resize_as_sparse_": SkipFunctionVariable,
"torch.get_default_device": TorchInGraphFunctionVariable,
# functorch/vmap
@ -323,8 +325,6 @@ manual_torch_name_rule_map: dict[
"torch._functorch.deprecated.grad_and_value": UserFunctionVariable,
"torch._functorch.deprecated.vjp": UserFunctionVariable,
# functorch/C++ bindings
"torch._C._functorch._add_batch_dim": TorchInGraphFunctionVariable,
"torch._C._functorch._remove_batch_dim": TorchInGraphFunctionVariable,
"torch._C._functorch._wrap_for_grad": TorchInGraphFunctionVariable,
"torch._C._functorch._unwrap_for_grad": TorchInGraphFunctionVariable,
"torch._C._functorch._unwrap_batched": TorchInGraphFunctionVariable,
@ -333,6 +333,8 @@ manual_torch_name_rule_map: dict[
"torch._C._functorch.is_batchedtensor": TorchInGraphFunctionVariable,
"torch._C._functorch.peek_interpreter_stack": TorchInGraphFunctionVariable,
"torch._C._functorch.unwrap_if_dead": TorchInGraphFunctionVariable,
"torch._functorch.predispatch._vmap_increment_nesting": TorchInGraphFunctionVariable,
"torch._functorch.predispatch._vmap_decrement_nesting": TorchInGraphFunctionVariable,
# everything else
"torch._functorch.pyfunctorch.coerce_cinterpreter": TorchInGraphFunctionVariable,
"torch._higher_order_ops.triton_kernel_wrap.do_prune_configs": UserFunctionVariable,
@ -2364,7 +2366,11 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch._functorch.utils.enable_single_level_autograd_function",
"torch._functorch.utils.exposed_in",
"torch._functorch.utils.unwrap_dead_wrappers",
"torch._functorch.vmap.lazy_load_decompositions",
"torch._functorch.predispatch.lazy_load_decompositions",
"torch._functorch.predispatch._vmap_increment_nesting",
"torch._functorch.predispatch._vmap_decrement_nesting",
"torch._functorch.predispatch._add_batch_dim",
"torch._functorch.predispatch._remove_batch_dim",
"torch._guards.compile_context",
"torch._guards.detect_fake_mode",
"torch._guards.tracing",
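
A rough way to sanity-check these table entries (a hypothetical snippet, not part of the change): Dynamo's rule lookup should map the new predispatch wrappers to TorchInGraphFunctionVariable, meaning they are kept as graph nodes rather than traced into.

import torch._functorch.predispatch as predispatch
from torch._dynamo import trace_rules

for fn in (
    predispatch._add_batch_dim,
    predispatch._remove_batch_dim,
    predispatch._vmap_increment_nesting,
    predispatch._vmap_decrement_nesting,
):
    # expected: TorchInGraphFunctionVariable for each wrapper
    print(fn.__name__, trace_rules.lookup(fn))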
@ -2985,6 +2985,8 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
torch.seed,
operator.mod,
torch._functorch.vmap._validate_and_get_batch_size,
torch._functorch.predispatch._vmap_increment_nesting,
torch._functorch.predispatch._vmap_decrement_nesting,
# some mac builds are missing torch.distributed.get_rank()
getattr(torch.distributed, "get_rank", _missing),
getattr(torch.distributed, "get_world_size", _missing),
@ -3018,9 +3020,8 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
):
set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options)
elif (
isinstance(example_value, (int, float, bool))
and proxy.node.target is call_torchbind
elif isinstance(example_value, (int, float, bool)) and (
proxy.node.target is call_torchbind
):
set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options)
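
For context on why these wrappers can sit in the constant-fold list above (an illustrative check, not from the diff): outside of export they fall through to the C++ bindings, which return plain Python values.

import torch
from torch._functorch import predispatch

level = predispatch._vmap_increment_nesting(4, "error")
print(type(level))  # expected to be a plain int (the new vmap nesting level)
predispatch._vmap_decrement_nesting()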
@ -523,7 +523,7 @@ class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch._vmap_increment_nesting,
torch._functorch.predispatch._vmap_increment_nesting,
(batch_size_node, randomness),
{},
)
@ -532,7 +532,10 @@ class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
def exit(self, tx: "InstructionTranslator", *args):
self.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
"call_function",
torch._functorch.predispatch._vmap_decrement_nesting,
(),
{},
)
return variables.ConstantVariable.create(None)
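
A hedged sketch of what the Dynamo change above means in practice (the backend and function here are illustrative): a compiled region that uses torch.vmap should now emit calls to the torch._functorch.predispatch wrappers instead of the raw torch._C._functorch bindings.

import torch

def print_graph_backend(gm, example_inputs):
    # debugging backend: show the captured FX graph, then run it eagerly
    gm.graph.print_tabular()
    return gm.forward

@torch.compile(backend=print_graph_backend, fullgraph=True)
def f(x, y):
    return torch.vmap(lambda a, b: a * b)(x, y)

f(torch.randn(4, 3), torch.randn(4, 3))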
@ -19,6 +19,7 @@ from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.experimental.proxy_tensor import PreDispatchTorchFunctionMode
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts


@ -211,6 +212,29 @@ def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> dict[str, Any]:
return params_buffers_to_node_meta


def _maybe_find_pre_dispatch_tf_mode_for_export():
if not torch._C._is_torch_function_mode_enabled():
return None

torch_function_mode_stack = torch.overrides._get_current_function_mode_stack()

pre_dispatch_tf_modes = [
mode
for mode in torch_function_mode_stack
if isinstance(mode, PreDispatchTorchFunctionMode)
]

assert len(pre_dispatch_tf_modes) <= 1, (
f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}"
)

if len(pre_dispatch_tf_modes) == 0:
return None

mode = pre_dispatch_tf_modes[0]
return mode


def _populate_param_buffer_metadata_to_new_gm(
params_buffers_to_node_meta: dict[str, Any],
gm: torch.fx.GraphModule,
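
How callers are expected to use the helper above (a short illustrative sketch): fetch the active PreDispatchTorchFunctionMode, if any, and proxy through its tracer.

from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export

mode = _maybe_find_pre_dispatch_tf_mode_for_export()
if mode is not None:
    # mode.tracer is the proxy tracer that records pre-dispatch graph nodes;
    # both predispatch.py and the subclass-constructor path below rely on it.
    tracer = mode.tracer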
@ -223,6 +223,11 @@ class Verifier(metaclass=_VerifierMeta):
torch.amp.autocast_mode._enter_autocast,
torch.amp.autocast_mode._exit_autocast,
torch.fx.experimental.symbolic_shapes.cast_symbool_to_symint_guardless,
torch._functorch.predispatch._add_batch_dim,
torch._functorch.predispatch._remove_batch_dim,
torch._functorch.predispatch._vmap_increment_nesting,
torch._functorch.predispatch._vmap_decrement_nesting,
torch._functorch.predispatch.lazy_load_decompositions,
)

if not isinstance(op, _allowed_op_types()):
@ -4,6 +4,7 @@ from contextlib import contextmanager
import torch
import torch._custom_ops
from torch._C import DispatchKey
from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
from torch._higher_order_ops.flat_apply import (
_ConstantFunction,
flat_apply,
@ -186,23 +187,12 @@ def mark_subclass_constructor_exportable_experimental(constructor_subclass):
f"tensor subclass. Please look at DTensor.__init__ implementation as an example of proper usage of this API."
)
constructor_subclass(*args, **kwargs)
if not torch._C._is_torch_function_mode_enabled():
return
torch_function_mode_stack = torch.overrides._get_current_function_mode_stack()

pre_dispatch_tf_modes = [
mode
for mode in torch_function_mode_stack
if isinstance(mode, PreDispatchTorchFunctionMode)
]
assert len(pre_dispatch_tf_modes) <= 1, (
f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}"
)

if len(pre_dispatch_tf_modes) == 0:
mode = _maybe_find_pre_dispatch_tf_mode_for_export()
if mode is None:
return

mode = pre_dispatch_tf_modes[0]
assert isinstance(mode, PreDispatchTorchFunctionMode)

tracer = mode.tracer
subclass = args[0]
158
torch/_functorch/predispatch.py
Normal file
@ -0,0 +1,158 @@
# mypy: ignore-errors

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This module contains pre-dispatch wrappers for functorch operations
that enable proper tracing in PT2 non-strict export/compile fx graph.
"""

import torch
from torch._C._functorch import (
    _add_batch_dim as _add_batch_dim_impl,
    _remove_batch_dim as _remove_batch_dim_impl,
    _vmap_decrement_nesting as _vmap_decrement_nesting_impl,
    _vmap_increment_nesting as _vmap_increment_nesting_impl,
)

def _add_batch_dim(self, batch_dim, level):
    """
    Thin wrapper around torch._C._functorch._add_batch_dim that is used to proxy in
    PT2 export/compile fx graph
    """
    from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export

    mode = _maybe_find_pre_dispatch_tf_mode_for_export()

    if mode:
        return torch.overrides.handle_torch_function(
            _add_batch_dim, (self,), self, batch_dim, level
        )

    res = _add_batch_dim_impl(self, batch_dim, level)
    return res


def _remove_batch_dim(self, level, batch_size, out_dim):
    """
    Thin wrapper around torch._C._functorch._remove_batch_dim that is used to proxy in
    PT2 export/compile fx graph
    """
    from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export

    mode = _maybe_find_pre_dispatch_tf_mode_for_export()

    if mode:
        return torch.overrides.handle_torch_function(
            _remove_batch_dim, (self,), self, level, batch_size, out_dim
        )

    res = _remove_batch_dim_impl(self, level, batch_size, out_dim)
    return res


def _vmap_increment_nesting(batch_size, randomness):
    """
    Thin wrapper around torch._C._functorch._vmap_increment_nesting that is used
    to proxy in export/compile graph
    """
    from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export

    mode = _maybe_find_pre_dispatch_tf_mode_for_export()

    if mode:
        return torch.overrides.handle_torch_function(
            _vmap_increment_nesting, (batch_size,), batch_size, randomness
        )
    res = _vmap_increment_nesting_impl(batch_size, randomness)
    return res


def _vmap_decrement_nesting():
    """
    Thin wrapper around torch._C._functorch._vmap_decrement_nesting that is used
    to proxy in export/compile graph
    """
    from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export

    mode = _maybe_find_pre_dispatch_tf_mode_for_export()

    if mode:
        return torch.overrides.handle_torch_function(
            _vmap_decrement_nesting,
            (),
        )
    return _vmap_decrement_nesting_impl()

# Global variables for lazy_load_decompositions
DECOMPOSITIONS_LOADED = False
DECOMPOSITIONS_LOCK = None  # Will be initialized when needed
VMAP_DECOMPOSITIONS_LIB = None


def lazy_load_decompositions():
    """
    Lazy loading of vmap decompositions with pre-dispatch support.
    """
    from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export

    mode = _maybe_find_pre_dispatch_tf_mode_for_export()

    if mode:
        return torch.overrides.handle_torch_function(lazy_load_decompositions, ())

    global DECOMPOSITIONS_LOADED, DECOMPOSITIONS_LOCK, VMAP_DECOMPOSITIONS_LIB

    if DECOMPOSITIONS_LOADED:
        return

    # Initialize lock if needed
    if DECOMPOSITIONS_LOCK is None:
        import threading

        DECOMPOSITIONS_LOCK = threading.Lock()

    with DECOMPOSITIONS_LOCK:
        if DECOMPOSITIONS_LOADED:
            return

        import os

        if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
            DECOMPOSITIONS_LOADED = True
            return

        # use an alternate way to register an operator into the decomposition table
        # _register_jit_decomposition doesn't work for some operators, e.g. addr,
        # because the Tensor types generated cannot be unioned by torchscript
        # decomp should be type OpOverload
        VMAP_DECOMPOSITIONS_LIB = torch.library.Library(
            "aten", "IMPL", "FuncTorchBatched"
        )

        from torch._decomp import decomposition_table

        def _register_python_decomposition_vmap(decomp):
            if decomp in decomposition_table:
                VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
            else:
                raise RuntimeError(f"could not find decomposition for {decomp}")

        _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
        _register_python_decomposition_vmap(
            torch.ops.aten.smooth_l1_loss_backward.default
        )
        _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.addr.default)

        DECOMPOSITIONS_LOADED = True
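
Outside of export these wrappers are plain pass-throughs to the C++ bindings, so eager vmap behavior is unchanged. An illustrative round trip (shapes are arbitrary), mirroring what vmap_impl does around the user function:

import torch
from torch._functorch import predispatch

x = torch.randn(4, 3)
predispatch._vmap_increment_nesting(4, "error")   # open a vmap level of size 4
bt = predispatch._add_batch_dim(x, 0, 1)          # hide dim 0 behind vmap level 1
out = predispatch._remove_batch_dim(bt, 1, 4, 0)  # move the batch dim back to position 0
predispatch._vmap_decrement_nesting()             # close the vmap level
assert torch.equal(out, x)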
@ -9,19 +9,18 @@
import contextlib
import functools
import itertools
import os
import threading
from functools import partial
from typing import Any, Callable, Optional, Union

import torch
from torch import Tensor
from torch._C._functorch import (
from torch._C._functorch import is_batchedtensor
from torch._functorch.predispatch import (
_add_batch_dim,
_remove_batch_dim,
_vmap_decrement_nesting,
_vmap_increment_nesting,
is_batchedtensor,
lazy_load_decompositions,
)
from torch.utils._pytree import (
_broadcast_to_and_flatten,
@ -258,57 +257,6 @@ def _get_name(func: Callable):
return repr(func)


DECOMPOSITIONS_LOADED = False
DECOMPOSITIONS_LOCK = threading.Lock()
VMAP_DECOMPOSITIONS_LIB = None


# torch.package, Python 3.11, and torch.jit-less environments are unhappy with
# decompositions. Only load them when needed if possible.
def lazy_load_decompositions():
global DECOMPOSITIONS_LOADED
if DECOMPOSITIONS_LOADED:
return

with DECOMPOSITIONS_LOCK:
if DECOMPOSITIONS_LOADED:
return

if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
DECOMPOSITIONS_LOADED = True
return

# use an alternate way to register an operator into the decomposition table
# _register_jit_decomposition doesn't work for some operators, e.g. addr,
# because the Tensor types generated cannot be unioned by torchscript
# decomp should be type OpOverload
global VMAP_DECOMPOSITIONS_LIB
VMAP_DECOMPOSITIONS_LIB = torch.library.Library(
"aten", "IMPL", "FuncTorchBatched"
)

from torch._decomp import decomposition_table

def _register_python_decomposition_vmap(decomp):
if decomp in decomposition_table:
VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
else:
raise RuntimeError(f"could not find decomposition for {decomp}")

_register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
_register_python_decomposition_vmap(
torch.ops.aten.smooth_l1_loss_backward.default
)
_register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
_register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
_register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.addr.default)

DECOMPOSITIONS_LOADED = True


def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
lazy_load_decompositions()
_check_out_dims_is_int_or_int_pytree(out_dims, func)
@ -435,7 +435,6 @@ class Tracer(TracerBase):
setattr(self.root, qualname, a)

return self.create_node("get_attr", qualname, (), {})

return super().create_arg(a)

@compatibility(is_backward_compatible=True)
@ -817,6 +817,49 @@ def _maybe_record_pointwise_barrier(
last_node.meta["low_precision_pointwise_barrier"] = True


def _fetch_proxies_and_all_constant_flag(
flat_args_kwargs: Union[list[object], tuple[object, ...]], tracer: _ProxyTracer
) -> tuple[list[object], tuple[object, ...], bool]:
"""
Given flat arguments, fetch the proxies and whether they are all constants.
This is later used in proxy_call, or when someone is trying to stitch together
graph nodes in torch_function (tf) or torch_dispatch (td) modes.
"""
f_flat_args_kwargs = [
(
fetch_object_proxy(tracer, x)
if isinstance(x, (Tensor, _AnyScriptObject))
else x
)
for x in flat_args_kwargs
]

# If there are SymInts, we also should not consider this constant.
# However, fake tensor handling of SymInts is sufficiently broken that
# I couldn't write a test for this case
all_constant = (
not any(
t.constant is None
for t in f_flat_args_kwargs
if isinstance(t, _ProxyTensor)
)
# TODO: maybe constant SymInts should also be allowed? Not sure if
# this can happen
and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs)
)

proxy_flat_args_kwargs = [
e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs
]

proxy_flat_args_kwargs = [
(fetch_sym_proxy(tracer)(e) if isinstance(e, py_sym_types) else e)
for e in proxy_flat_args_kwargs
]

return f_flat_args_kwargs, tuple(proxy_flat_args_kwargs), all_constant


def proxy_call(
proxy_mode: ProxyTorchDispatchMode,
func: OpOverload,
@ -869,27 +912,8 @@ def proxy_call(
return (args[0] != 0).item() # type: ignore[attr-defined]

tracer = proxy_mode.tracer
f_flat_args_kwargs = [
(
fetch_object_proxy(tracer, x)
if isinstance(x, (Tensor, _AnyScriptObject))
else x
)
for x in flat_args_kwargs
]

# If there are SymInts, we also should not consider this constant.
# However, fake tensor handling of SymInts is sufficiently broken that
# I couldn't write a test for this case
all_constant = (
not any(
t.constant is None
for t in f_flat_args_kwargs
if isinstance(t, _ProxyTensor)
)
# TODO: maybe constant SymInts should also be allowed? Not sure if
# this can happen
and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs)
f_flat_args_kwargs, proxy_flat_args_kwargs, all_constant = (
_fetch_proxies_and_all_constant_flag(flat_args_kwargs, tracer)
)

if torch.Tag.data_dependent_output in func.tags:
@ -917,13 +941,6 @@ def proxy_call(
"in your make_fx call."
)

proxy_flat_args_kwargs = [
e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs
]
proxy_flat_args_kwargs = [
(fetch_sym_proxy(proxy_mode.tracer)(e) if isinstance(e, py_sym_types) else e)
for e in proxy_flat_args_kwargs
]
proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec)

# When we trace through a torch.tensor invocation, you never actually
@ -1435,6 +1452,27 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
if func is torch._C._set_grad_enabled:
func(*args, **kwargs)
return node

# We need more complicated handling here because the inputs
# to these functions are sometimes tensors or symints where
# we need to fetch the proxies properly.
if func in [
torch._functorch.predispatch._add_batch_dim,
torch._functorch.predispatch._remove_batch_dim,
torch._functorch.predispatch._vmap_increment_nesting,
torch._functorch.predispatch._vmap_decrement_nesting,
torch._functorch.vmap.lazy_load_decompositions,
]:
_, proxies, _ = _fetch_proxies_and_all_constant_flag(args, self.tracer)
out_proxy = self.tracer.create_proxy(
"call_function",
func,
proxies,
{},
)
res = func(*args, **kwargs)
track_tensor_tree(res, out_proxy, constant=None, tracer=self.tracer)
return res
return func(*args, **kwargs)
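
The torch-function handling above is what makes these wrappers show up when tracing with the pre-dispatch proxy mode. A rough sketch (the traced function and shapes are illustrative):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    return torch.vmap(lambda a, b: a * b)(x, y)

# pre_dispatch=True installs PreDispatchTorchFunctionMode, so the predispatch
# wrapper calls should be recorded as call_function nodes in the traced graph.
gm = make_fx(f, pre_dispatch=True)(torch.randn(4, 3), torch.randn(4, 3))
print(gm.code)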
@ -56,6 +56,7 @@ import torch.utils._pytree as pytree

# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import SymBool, SymFloat, SymInt
from torch._C._functorch import get_unwrapped, is_batchedtensor
from torch._guards import ShapeGuard, SLoc, Source, TracingContext
from torch._logging import dtrace_structured, LazyString, structured, trace_structured
from torch._subclasses.meta_utils import is_sparse_any
@ -1146,7 +1147,10 @@ def _free_unbacked_symbols_with_path(
for attr in attrs:
sub = getattr(a, attr)
r.update(go(sub, path + (InnerTensorKey(attr),)))
elif isinstance(a, torch.Tensor):
elif isinstance(a, torch.Tensor) and is_batchedtensor(a):
unwrapped_tensor = get_unwrapped(a)
r.update(go(unwrapped_tensor, path))
elif isinstance(a, torch.Tensor) and not is_batchedtensor(a):
from torch._subclasses.fake_tensor import FakeTensor

assert isinstance(a, FakeTensor)