Compare commits

...

72 Commits

Author SHA1 Message Date
74e5c3d921 Update
[ghstack-poisoned]
2025-11-14 00:17:43 -08:00
00db1dfc23 Update (base update)
[ghstack-poisoned]
2025-11-14 00:17:43 -08:00
2d586c736d Update
[ghstack-poisoned]
2025-11-12 17:14:00 -08:00
8ab1d524f7 Update (base update)
[ghstack-poisoned]
2025-11-12 17:14:00 -08:00
2750a7fee0 Update
[ghstack-poisoned]
2025-11-12 15:26:30 -08:00
2817633c36 Update
[ghstack-poisoned]
2025-11-12 15:16:41 -08:00
fb90c47154 Update (base update)
[ghstack-poisoned]
2025-11-12 13:05:28 -08:00
3fc9fb48ff Update
[ghstack-poisoned]
2025-11-12 13:05:28 -08:00
2039ee85db Update
[ghstack-poisoned]
2025-11-11 22:25:22 -08:00
e3fbec109f Update (base update)
[ghstack-poisoned]
2025-11-11 18:14:42 -08:00
0d8b37da5f Update
[ghstack-poisoned]
2025-11-11 18:14:42 -08:00
a7db673e53 Update (base update)
[ghstack-poisoned]
2025-11-09 22:14:31 -08:00
1d1f83dd92 Update
[ghstack-poisoned]
2025-11-09 22:14:31 -08:00
a21b825f19 Update (base update)
[ghstack-poisoned]
2025-11-08 23:01:02 -08:00
e507e07361 Update
[ghstack-poisoned]
2025-11-08 23:01:02 -08:00
6531323366 Update (base update)
[ghstack-poisoned]
2025-11-08 16:31:08 -08:00
373042f6f0 Update
[ghstack-poisoned]
2025-11-08 16:31:08 -08:00
390e9364af Update (base update)
[ghstack-poisoned]
2025-11-08 10:46:59 -08:00
68622b9c6e Update
[ghstack-poisoned]
2025-11-08 10:46:59 -08:00
33425d4a8e Update (base update)
[ghstack-poisoned]
2025-11-08 09:09:00 -08:00
4784498d05 Update
[ghstack-poisoned]
2025-11-08 09:09:00 -08:00
4be08d8c28 Update (base update)
[ghstack-poisoned]
2025-11-08 08:57:03 -08:00
b132cc8a68 Update
[ghstack-poisoned]
2025-11-08 08:57:03 -08:00
2a36b972d4 Update (base update)
[ghstack-poisoned]
2025-11-08 07:41:42 -08:00
c9f22e75ae Update
[ghstack-poisoned]
2025-11-08 07:41:42 -08:00
fe1afe16e3 Update (base update)
[ghstack-poisoned]
2025-11-07 22:55:01 -08:00
265a34a5eb Update
[ghstack-poisoned]
2025-11-07 22:55:01 -08:00
6486b54865 Update (base update)
[ghstack-poisoned]
2025-11-07 22:51:20 -08:00
0690194dd4 Update
[ghstack-poisoned]
2025-11-07 22:51:20 -08:00
87b47990a8 Update
[ghstack-poisoned]
2025-11-07 22:15:23 -08:00
395f9c3ca1 Update (base update)
[ghstack-poisoned]
2025-11-07 21:23:17 -08:00
c332695a95 Update
[ghstack-poisoned]
2025-11-07 21:23:17 -08:00
81ea254188 Update base for Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 18:47:39 -08:00
bf03db27c8 Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 18:47:39 -08:00
631b12e8ba Update base for Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 17:44:39 -08:00
18ef4c3128 Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 17:44:39 -08:00
4b2c463fd0 Update base for Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 16:01:55 -08:00
080c5d1331 Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 16:01:55 -08:00
e81aaecc29 Update base for Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 14:18:21 -08:00
a4b05b0738 Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 14:18:21 -08:00
53a946f7f9 Update base for Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 13:55:15 -08:00
f5d3f8901c Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 13:55:15 -08:00
7a832991fd Update base for Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 13:44:45 -08:00
dbf0853ecf Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 13:44:45 -08:00
a22b8ececc Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-02 08:22:41 -08:00
29af216024 Update on "[dynamo] Rehaul the autograd.Function support"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-11-01 16:21:03 -07:00
7b0c475bc4 [dynamo] Rehaul the autograd.Function support
[ghstack-poisoned]
2025-10-31 23:31:25 -07:00
b552a4eba1 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 16:43:15 -07:00
b3e120665b Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 16:43:15 -07:00
96a8d1c5e0 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:37:38 -07:00
39307c3db2 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:37:38 -07:00
3d6061d56a Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:34:54 -07:00
dc55769bb6 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:34:54 -07:00
2c74beddf6 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:33:32 -07:00
12ff17857e Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 13:33:32 -07:00
4ae3c59ce2 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 12:19:15 -07:00
7a8ad5f874 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-28 12:19:15 -07:00
dd09fa089d Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 15:37:16 -07:00
0995593caa Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 15:37:16 -07:00
69a4358a01 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 15:24:52 -07:00
0ab9e050ab Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 15:24:52 -07:00
651e9dbf94 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 14:08:25 -07:00
56bd4c695a Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 14:08:24 -07:00
1cb7be9419 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 13:34:38 -07:00
00f68803d3 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 13:34:38 -07:00
de1f732075 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 13:31:13 -07:00
cbfee32779 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 13:31:13 -07:00
0e38867920 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 12:57:54 -07:00
cefd269c35 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 12:57:54 -07:00
7fcf3a1488 Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 12:28:36 -07:00
9a88bd06e1 Update base for Update on "[DONT MERGE] Get rid of FUNCTION_MATCH"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-27 12:28:36 -07:00
ccc9750df1 [DONT MERGE] Get rid of FUNCTION_MATCH
[ghstack-poisoned]
2025-10-27 11:38:49 -07:00
11 changed files with 837 additions and 388 deletions

View File

@ -8,6 +8,7 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch._dynamo.testing import AotEagerAndRecordGraphs
from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu
@ -604,22 +605,23 @@ class GraphModule(torch.nn.Module):
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
return (autograd_function_apply,)
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_weird_b, l_weird_c, l_x_, l_z_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_weird_b = l_weird_c = l_x_ = l_z_ = None
getitem: "f32[]" = autograd_function_apply[0]; autograd_function_apply = None
return (getitem,)
class fwd_body_0(torch.nn.Module):
def forward(self, ctx : torch.autograd.function.Function, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
def forward(self, l_weird_b: "f32[]", l_weird_c: "f32[]", l_x_: "f32[]", l_z_: "f32[]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
mul: "f32[]" = l_weird_b * l_weird_c
clone: "f32[]" = x.clone(); x = None
mul_1: "f32[]" = mul * clone; mul = clone = None
clone: "f32[]" = l_x_.clone(); l_x_ = None
outs: "f32[]" = mul * clone; mul = clone = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return (mul_1, [l_weird_b, l_weird_c])
return ((outs,), (l_weird_b, l_weird_c))
class bwd_body_0(torch.nn.Module):
def forward(self, ctx : torch.autograd.function.Function, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
def forward(self, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
mul: "f32[]" = grad * l_weird_b; l_weird_b = None
@ -627,7 +629,7 @@ class GraphModule(torch.nn.Module):
mul_2: "f32[]" = grad * 2; grad = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return (mul_1, mul_2)
return (None, None, mul_1, mul_2)
""",
)
@ -1125,32 +1127,33 @@ class GraphModule(torch.nn.Module):
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply: "f32[5, 4]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_weight_, args_tensor_mask = [True, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_weight_ = None
return (autograd_function_apply,)
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_weight_, l_x_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_weight_ = l_x_ = None
getitem: "f32[5, 4]" = autograd_function_apply[0]; autograd_function_apply = None
return (getitem,)
class fwd_body_0(torch.nn.Module):
def forward(self, ctx : torch.autograd.function.Function, x: "f32[5, 3]", weight: "f32[4, 3]"):
def forward(self, l_weight_: "f32[4, 3]", l_x_: "f32[5, 3]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
t: "f32[3, 4]" = weight.t()
y: "f32[5, 4]" = x.matmul(t); t = None
t: "f32[3, 4]" = l_weight_.t()
y: "f32[5, 4]" = l_x_.matmul(t); t = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return (y, [weight, x])
return ((y,), (l_weight_, l_x_))
class bwd_body_0(torch.nn.Module):
def forward(self, function_ctx : torch.autograd.function.Function, y: "f32[5, 4]", weight: "f32[4, 3]", x: "f32[5, 3]"):
def forward(self, y: "f32[5, 4]", l_weight_: "f32[4, 3]", l_x_: "f32[5, 3]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
contiguous: "f32[5, 4]" = y.contiguous(); y = None
grad_x: "f32[5, 3]" = contiguous.matmul(weight); weight = None
grad_x: "f32[5, 3]" = contiguous.matmul(l_weight_); l_weight_ = None
transpose: "f32[4, 5]" = contiguous.transpose(0, 1); contiguous = None
grad_weight: "f32[4, 3]" = transpose.matmul(x); transpose = x = None
grad_weight: "f32[4, 3]" = transpose.matmul(l_x_); transpose = l_x_ = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return (grad_x, grad_weight)
return (grad_weight, grad_x)
""",
)
@ -1259,7 +1262,7 @@ class GraphModule(torch.nn.Module):
def foo(x):
return Foo.apply(x)
foo(torch.randn(2, requires_grad=True))
foo(torch.randn(2, requires_grad=True)).sum().backward()
self.assertEqual(cnts.frame_count, 1)
def test_mark_non_differentiable(self):
@ -1309,24 +1312,24 @@ class GraphModule(torch.nn.Module):
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
getitem: "f32[]" = autograd_function_apply[0]
getitem_1: "f32[]" = autograd_function_apply[1]; autograd_function_apply = None
return (getitem, getitem_1)
class fwd_body_0(torch.nn.Module):
def forward(self, ctx : torch.autograd.function.Function, x: "f32[]", y: "f32[]"):
def forward(self, l_x_: "f32[]", l_y_: "f32[]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
out1: "f32[]" = x.sin(); x = None
out1: "f32[]" = l_x_.sin(); l_x_ = None
out2: "f32[]" = y * 2; y = None
out2: "f32[]" = l_y_ * 2; l_y_ = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return ((out1, out2), [])
return ((out1, out2), ())
class bwd_body_0(torch.nn.Module):
def forward(self, ctx : torch.autograd.function.Function, grad1: "f32[]", grad2: "f32[]"):
def forward(self, grad1: "f32[]", grad2: "f32[]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
cos: "f32[]" = grad1.cos(); grad1 = None
@ -1438,7 +1441,7 @@ class GraphModule(torch.nn.Module):
# would be generated by autograd engine.
return result * 0.5
@torch.compile(backend="eager", fullgraph=True)
@torch.compile(backend="aot_eager", fullgraph=True)
def fn(x):
x, _ = MyCube.apply(x)
x, _ = MyCube.apply(x)
@ -1475,7 +1478,7 @@ class GraphModule(torch.nn.Module):
self.assertEqual(out, x + 1)
self.assertEqual(x.grad.shape, shape)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 1)
self.assertEqual(cnt.op_count, 2)
@requires_gpu
def test_triton_kernel_basic(self):
@ -1543,6 +1546,238 @@ class GraphModule(torch.nn.Module):
loss.backward()
self.assertEqual(x + y, z)
def test_rewired_bwd_output(self):
class Add(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
a = torch.sin(x)
b = torch.cos(y)
result = a * b
# Save input, output and intermediate to test all cases
ctx.save_for_backward(a, x, result)
return result, a + b
@staticmethod
def backward(ctx, grad_a, grad_b):
(a, x, result) = ctx.saved_tensors
return a * grad_b * 2 + x, result + grad_a * 3
def fn(x, y):
z = Add.apply(torch.cos(x), torch.cos(y))
return z[0] + z[1]
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, fullgraph=True, backend=backend)
x = torch.randn(8, 8, requires_grad=True)
y = torch.randn(8, 8, requires_grad=True)
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
torch._dynamo.mark_dynamic(x_clone, 0)
torch._dynamo.mark_dynamic(y_clone, 0)
ref = fn(x, y)
res = opt_fn(x_clone, y_clone)
ref.sum().backward()
res.sum().backward()
self.assertEqual(ref, res)
self.assertEqual(x.grad, x_clone.grad)
self.assertExpectedInline(
torch._dynamo.testing.normalize_gm(
backend.graphs[0].print_readable(print_output=False)
),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s17)", L_x_: "f32[s17, 8]", s17: "Sym(s17)", L_y_: "f32[s17, 8]"):
l_x_ = L_x_
l_y_ = L_y_
arg: "f32[s17, 8]" = torch.cos(l_x_); l_x_ = None
arg_1: "f32[s17, 8]" = torch.cos(l_y_); l_y_ = None
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, s77, arg, s17, arg_1, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = s77 = arg = s17 = arg_1 = None
getitem: "f32[s17, 8]" = autograd_function_apply[0]
getitem_1: "f32[s17, 8]" = autograd_function_apply[1]; autograd_function_apply = None
add: "f32[s17, 8]" = getitem + getitem_1; getitem = getitem_1 = None
return (add,)
class fwd_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s17)", cos: "f32[s17, 8]", s17: "Sym(s17)", cos_1: "f32[s17, 8]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
a: "f32[s17, 8]" = torch.sin(cos)
b: "f32[s17, 8]" = torch.cos(cos_1); cos_1 = None
result: "f32[s17, 8]" = a * b
out: "f32[s17, 8]" = a + b; b = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return ((result, out), (s17, a, cos, result))
class bwd_body_0(torch.nn.Module):
def forward(self, grad_a: "f32[s17, 8]", grad_b: "f32[s17, 8]", s17: "Sym(s17)", a: "f32[s17, 8]", arg: "f32[s17, 8]", result: "f32[s17, 8]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
mul: "f32[s17, 8]" = a * grad_b; a = grad_b = None
mul_1: "f32[s17, 8]" = mul * 2; mul = None
add: "f32[s17, 8]" = mul_1 + arg; mul_1 = arg = None
mul_2: "f32[s17, 8]" = grad_a * 3; grad_a = None
add_1: "f32[s17, 8]" = result + mul_2; result = mul_2 = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return (None, add, None, add_1)
""",
)
def test_udf_output(self):
class Foo:
def __init__(self, a, b):
self.a = a
self.b = b
class Add(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
a = torch.sin(x)
b = torch.cos(y)
ctx.save_for_backward(a)
return Foo(a, b), x * y
@staticmethod
def backward(ctx, grad_a, grad_b):
(a,) = ctx.saved_tensors
return grad_b * 2, grad_b * 3
def fn(x, y):
z = Add.apply(x, y)
return z[0].a + z[0].b + z[1]
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, fullgraph=True, backend=backend)
x = torch.randn(8, 8, requires_grad=True)
y = torch.randn(8, 8, requires_grad=True)
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
ref = fn(x, y)
res = opt_fn(x_clone, y_clone)
ref.sum().backward()
res.sum().backward()
self.assertEqual(ref, res)
self.assertEqual(x.grad, x_clone.grad)
self.assertExpectedInline(
torch._dynamo.testing.normalize_gm(
backend.graphs[0].print_readable(print_output=False)
),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[8, 8]", L_y_: "f32[8, 8]"):
l_x_ = L_x_
l_y_ = L_y_
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
getitem: "f32[8, 8]" = autograd_function_apply[0]
getitem_1: "f32[8, 8]" = autograd_function_apply[1]
getitem_2: "f32[8, 8]" = autograd_function_apply[2]; autograd_function_apply = None
add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None
add_1: "f32[8, 8]" = add + getitem_2; add = getitem_2 = None
return (add_1,)
class fwd_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]", l_y_: "f32[8, 8]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
a: "f32[8, 8]" = torch.sin(l_x_)
b: "f32[8, 8]" = torch.cos(l_y_)
out: "f32[8, 8]" = l_x_ * l_y_; l_x_ = l_y_ = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return ((a, b, out), ())
class bwd_body_0(torch.nn.Module):
def forward(self, unused_0, unused_1, grad_b: "f32[8, 8]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
mul: "f32[8, 8]" = grad_b * 2
mul_1: "f32[8, 8]" = grad_b * 3; grad_b = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return (mul, mul_1)
""",
)
def test_aliasing_output(self):
class Add(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, grad_out):
return grad_out
def fn(x):
y = Add.apply(x)
if y is x:
return torch.cos(y)
return torch.sin(y)
x = torch.randn(8, 8, requires_grad=True)
ref = fn(x)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, fullgraph=True, backend=backend)
res = opt_fn(x)
self.assertEqual(ref, res)
# Must have `view_as`
self.assertExpectedInline(
torch._dynamo.testing.normalize_gm(
backend.graphs[0].print_readable(print_output=False)
),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[8, 8]"):
l_x_ = L_x_
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = None
y: "f32[8, 8]" = autograd_function_apply[0]; autograd_function_apply = None
sin: "f32[8, 8]" = torch.sin(y); y = None
return (sin,)
class fwd_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
view_as: "f32[8, 8]" = l_x_.view_as(l_x_); l_x_ = None
return ((view_as,), ())
class bwd_body_0(torch.nn.Module):
def forward(self, grad_out: "f32[8, 8]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return (grad_out,)
""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -10033,7 +10033,6 @@ def ___make_guard_fn():
# as the same across the two tracings. This is an unlikely situation in real use cases, so we add another
# `test_validate_outputs_unbacked_by_custom_op` to mitigate it and keep this one as expected failure
# until we have a proper fix.
@unittest.expectedFailure
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_validate_outputs_unbacked(self):
class SillyCat(torch.autograd.Function):

View File

@ -3796,7 +3796,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
expected = fn(*inputs1)
actual = fn_opt(*inputs2)
self.assertTrue(same(actual, expected))
self.assertEqual(cnt.op_count, 1)
self.assertEqual(cnt.op_count, 2)
self.assertEqual(cnt.frame_count, 1)
cnt.clear()
counters.clear()

View File

@ -2784,34 +2784,34 @@ def forward(self, add, tangents_1):
return (mul_1, None)""",
)
def test_backward_mutation_metadata(self):
class BwMutation(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
ctx.save_for_backward(b)
return a.clone(), b.clone()
# def test_backward_mutation_metadata(self):
# class BwMutation(torch.autograd.Function):
# @staticmethod
# def forward(ctx, a, b):
# ctx.save_for_backward(b)
# return a.clone(), b.clone()
@staticmethod
def backward(ctx, grad_a, grad_b):
(b,) = ctx.saved_tensors
# bw metadata mutation
b.transpose_(1, 0)
return grad_a.clone(), grad_b.clone()
# @staticmethod
# def backward(ctx, grad_a, grad_b):
# (b,) = ctx.saved_tensors
# # bw metadata mutation
# b.transpose_(1, 0)
# return grad_a.clone(), grad_b.clone()
def f(a, b):
a_, b_ = BwMutation.apply(a, b)
out = a_ * b_
return out
# def f(a, b):
# a_, b_ = BwMutation.apply(a, b)
# out = a_ * b_
# return out
inp_no_grad = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=False),
]
# inp_no_grad = [
# torch.ones(3, 3, requires_grad=True),
# torch.ones(3, 3, requires_grad=False),
# ]
with self.assertRaisesRegex(
AssertionError, "input that had its metadata mutated in the backward"
):
self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
# with self.assertRaisesRegex(
# AssertionError, "input that had its metadata mutated in the backward"
# ):
# self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
def test_backward_mutation_on_grad_out(self):
class BwMutation(torch.autograd.Function):
@ -8681,7 +8681,7 @@ class MockFXGraphCache:
FAILING_CACHE_TESTS = (
# BypassAOTAutogradCache: unsupported nodes
"test_backward_mutation_data", # Custom Autograd Function
"test_backward_mutation_metadata", # Custom Autograd Function
# "test_backward_mutation_metadata", # Custom Autograd Function
"test_input_output_aliase_custom_autograd_function",
)

View File

@ -2048,27 +2048,28 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]"):
fwd_body_0 = self.fwd_body_0
bwd_body_0 = self.bwd_body_0
autograd_function_apply: "f32[8, 8]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, args_tensor_mask = [True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = None
return (autograd_function_apply,)
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = None
getitem: "f32[8, 8]" = autograd_function_apply[0]; autograd_function_apply = None
return (getitem,)
class fwd_body_0(torch.nn.Module):
def forward(self, ctx : torch.autograd.function.Function, x: "f32[8, 8]"):
def forward(self, l_x_: "f32[8, 8]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
sin: "f32[8, 8]" = torch.sin(x)
outs: "f32[8, 8]" = torch.sin(l_x_)
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return (sin, [x])
return ((outs,), (l_x_,))
class bwd_body_0(torch.nn.Module):
def forward(self, ctx : torch.autograd.function.Function, grad_out: "f32[8, 8]", x: "f32[8, 8]"):
def forward(self, grad_out: "f32[8, 8]", l_x_: "f32[8, 8]"):
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
cos: "f32[8, 8]" = torch.cos(grad_out); grad_out = None
mul: "f32[8, 8]" = x * cos; x = cos = None
mul: "f32[8, 8]" = l_x_ * cos; l_x_ = cos = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
return mul
return (mul,)
""",
)

View File

@ -3515,6 +3515,12 @@
}
],
"GB0346": [
{
"Gb_type": "autograd.Function.apply: non-function or method forward",
"Context": "str(fn)",
"Explanation": "Expected {method_name} to be a function or method.",
"Hints": []
},
{
"Gb_type": "autograd.Function.apply: non-function or method forward",
"Context": "str(self.fwd_graph)",
@ -3667,5 +3673,15 @@
"Use custom operators instead of direct attribute/method access."
]
}
],
"GB0363": [
{
"Gb_type": "Unsupported input type in backward graph of autograd.Function",
"Context": "Unsupported node type {type(example_value)}",
"Explanation": "Node {node} has example_value {example_value} which is not a Tensor/Symint/None",
"Hints": [
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
]
}
]
}

View File

@ -20,6 +20,7 @@ their semantic behavior.
"""
import contextlib
import copy
import functools
import inspect
import itertools
@ -251,6 +252,66 @@ def _make_inlined(tx: "InstructionTranslator", f):
return inline_call
def add_call_function(
tx: "InstructionTranslator",
fn: Any,
args: tuple[Any, ...],
kwargs: dict[str, Any],
flat_example_value: Any,
):
from .builder import wrap_fx_proxy
# Store the invocation as a call
flat_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn,
args=args,
kwargs=kwargs,
),
example_value=flat_example_value,
)
return flat_variable
def overwrite_tensor_vt_requires_grad(graph_output_vts, flat_variable):
# this is required for faithfully representing the autograd.Function forward
# outputs.
for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items):
if isinstance(orig_vt, (variables.SymNodeVariable, variables.TensorVariable)):
assert isinstance(
subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable)
)
orig_vt.requires_grad = subgraph_vt.requires_grad
if orig_vt.requires_grad:
orig_vt.has_grad_fn = True
def overwrite_tensor_vt_proxy(graph_output_vts, flat_variable):
# wrap_fx_proxy creates fresh variable trackers. However, the main program
# after the speculate subgraph can still use the original tensor vts that
# are still pointing to the nodes present in the subgraph. So, we reproxify
# the original tensor vts with the subgraph outputs. This way, whenever the
# outer graph uses an original vt, it uses the subgraph output.
#
# This is critical for maintaining the separation between:
# - `body_r`: The output VT structure that Dynamo continues tracing (may
# contain non-proxyable objects, nested structures, etc.)
# - `graph_output_vts`: Only the tensor/symint VTs that were actual graph
# outputs from speculate_subgraph
#
# By overwriting the proxies of VTs in `body_r` with the proxies from the
# HOP call, we ensure the outer graph correctly references the HOP outputs
# while still allowing `body_r` to contain arbitrary Python objects.
for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items):
if isinstance(orig_vt, (variables.SymNodeVariable, variables.TensorVariable)):
assert isinstance(
subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable)
)
orig_vt.proxy = subgraph_vt.proxy
def _call_function_with_auto_output_flattening(
tx: "InstructionTranslator",
fn: Any,
@ -282,44 +343,10 @@ def _call_function_with_auto_output_flattening(
Returns:
The body_r VT (unchanged), which Dynamo will continue tracing with
"""
from .builder import wrap_fx_proxy
# Store the invocation as a call
flat_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn,
args=args,
kwargs=kwargs,
),
example_value=flat_example_value,
)
# wrap_fx_proxy creates fresh variable trackers. However, the main program
# after the speculate subgraph can still use the original tensor vts that
# are still pointing to the nodes present in the subgraph. So, we reproxify
# the original tensor vts with the subgraph outputs. This way, whenever the
# outer graph uses an original vt, it uses the subgraph output.
#
# This is critical for maintaining the separation between:
# - `body_r`: The output VT structure that Dynamo continues tracing (may
# contain non-proxyable objects, nested structures, etc.)
# - `graph_output_vts`: Only the tensor/symint VTs that were actual graph
# outputs from speculate_subgraph
#
# By overwriting the proxies of VTs in `body_r` with the proxies from the
# HOP call, we ensure the outer graph correctly references the HOP outputs
# while still allowing `body_r` to contain arbitrary Python objects.
flat_variable = add_call_function(tx, fn, args, kwargs, flat_example_value)
if body_r is not None:
for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items):
if isinstance(
orig_vt, (variables.SymNodeVariable, variables.TensorVariable)
):
assert isinstance(
subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable)
)
orig_vt.proxy = subgraph_vt.proxy
overwrite_tensor_vt_proxy(graph_output_vts, flat_variable)
return body_r
@ -823,16 +850,8 @@ def validate_args_and_maybe_create_graph_inputs(
if set_subgraph_inputs == "automatic":
args.append(a)
continue
elif set_subgraph_inputs == "semi_automatic":
if isinstance(a, AutogradFunctionContextVariable):
example_value = a.as_proxy().node.meta["example_value"]
arg_name = (
a.as_proxy().node.name
if sub_args_names is None
else sub_args_names[idx]
)
tracer.create_graph_input(arg_name, a.python_type(), example_value)
elif a.maybe_fx_node() is not None:
elif set_subgraph_inputs == "automatic_with_new_placeholder":
if isinstance(a, variables.TensorVariable):
node = a.maybe_fx_node()
example_value = node.meta["example_value"]
arg_name = (
@ -1197,7 +1216,7 @@ def speculate_subgraph_with_auto_output_flattening(
enable_grad: Optional[bool] = None,
# TODO - We can probably just make everyone use automatic for wrap_semantics
set_subgraph_inputs: Literal[
"automatic", "semi_automatic", "flatten_manual", "manual"
"automatic", "automatic_with_new_placeholder", "flatten_manual", "manual"
] = "automatic",
# Make default False
restore_side_effects: bool = True,
@ -1288,7 +1307,7 @@ def speculate_subgraph_with_auto_output_flattening(
assert set_subgraph_inputs in {
"automatic",
"semi_automatic",
"automatic_with_new_placeholder",
"flatten_manual",
"manual",
}, "Please use one of the supported set_subgraph_inputs options."
@ -1514,7 +1533,7 @@ def speculate_subgraph(
assert set_subgraph_inputs in {
"automatic",
"semi_automatic",
"automatic_with_new_placeholder",
"flatten_manual",
"manual",
}, "Please use one of the supported set_subgraph_inputs options."
@ -3675,10 +3694,10 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
class AutogradFunctionApplyVariable(VariableTracker):
def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs) -> None:
def __init__(self, fwd_fn, bwd_fn, parent_source, **kwargs) -> None:
super().__init__(**kwargs)
self.fwd_graph = fwd_graph
self.bwd_graph = bwd_graph
self.fwd_fn = fwd_fn
self.bwd_fn = bwd_fn
self.parent_source = parent_source
def call_function(
@ -3687,52 +3706,172 @@ class AutogradFunctionApplyVariable(VariableTracker):
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
from . import (
AutogradFunctionContextVariable,
UserDefinedClassVariable,
UserFunctionVariable,
UserMethodVariable,
)
from .builder import wrap_fx_proxy
"""
Consider the following:
At the highest level, the goal of tracing an autograd.Function is to
essentially emit a new autograd.Function object. To do this, Dynamo
traces fwd and bwd graph and then inserts a AutogradFunctionApply HOP in
the graph that call the traced fwd and bwd graph in the `forward` and
`backward` methods respectively. AOTDispatcher desugars this HOP and
just inlines the hop fwd and bwd into the main graph during its tracing.
However, the traced forward and backward graphs cannot be directly
placed in the new autograd.Function because autograd.Function has some
requirements.
a) # fwd graph inputs = # bwd graph outputs
b) # fwd graph outputs = # bwd graph inputs
c) Since the graphs do not have ctx variable, we have to manually return
the saved_tensors from the forward and have additional inputs in the
backward, and wire the connections.
Unfortunately, reworking the initial traced fwd and bwd graphs to
satisfy the above 3 conditions leads to a very tedious codebase.
Lets look at an example
class Foo:
def __init__(self):
self.a = 4
class MySin(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
def forward(ctx, x, foo):
ctx.save_for_backward(x)
return x.sin()
return x.sin() + foo.a
@staticmethod
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad * x.cos()
We want the resulting graphs to look like:
def fwd(ctx, x):
# Note that Dynamo lifts the foo_a directly as an input.
def fwd(ctx, x, foo_a):
# (output, saved tensors / attrs)
return (x.sin(), [x])
# bwd(ctx, grad0, grad1, ..., gradn, *saved_tensors_or_attrs)
return (x.sin() + foo_a, (x))
# Note that backward graph has None as the second output to match the
# fwd requirements (even though the original backward function has just
# output)
def bwd(ctx, grad, x):
return grad * x.cos()
return grad * x.cos(), None
To accomplish this, we're going to:
1. Construct a ctx object
2. (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph on MySin.forward (manually_set_inputs=True)
3. (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph on MySin.backward, while manually setting
the ctx and grad inputs.
4. Manually rewriting the fwd graph's output to be (output, stuff_that_gets_used in bwd_graph)
Getting from 3 to 4 is pretty elegant: stuff_that_gets_used in bwd graph is
just the bwd_freevars returned from speculate_subgraph, assuming MySin.backward
doesn't capture any arguments.
All these steps work if MySin.backward doesn't capture any values. This is a
limitation in general that we should check for.
2. Speculate subgraph forward
3. Speculate subgraph backward
4. rewired_bwd_graph_inputs - Use the traced fwd graph as the anchor point, and rewire the backward graph outputs
5. handle_saved_tensors_wiring - Hhandle the saved tensors, as mentioned in (c)
"""
prev_side_effects = tx.output.side_effects.clone()
fwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
tx.output,
parent=tx.output.current_tracer,
source_target="autograd.Function",
)
ctx = self.prepare_ctx_vt(tx, args, kwargs)
fwd_fn, fwd_out, fwd_graph, fwd_freevars, fwd_graph_output_vts = (
self.trace_forward_graph(tx, ctx, fwd_tracer, args, kwargs)
)
bwd_args, bwd_out, bwd_graph, bwd_freevars, bwd_graph_output_vts = (
self.trace_backward_graph(tx, ctx, fwd_tracer, fwd_out, fwd_fn)
)
# At this point, the fwd_out represents the output of the forward
# method. fwd_graph represents the tensor computation, its input and
# output do not match the original forward method. Same is true for the
# bwd_out and bwd_graph. However, in order to create a new
# autograd.Function to pass to the lower compiler, the fwd and bwd graph
# must be "consistent".
self.rewire_bwd_graph_inputs(
fwd_out, fwd_graph, fwd_freevars, bwd_out, bwd_graph, bwd_freevars, args
)
fwd_graph, bwd_graph = self.handle_saved_tensors_wiring(
ctx,
fwd_out,
fwd_graph,
fwd_freevars,
fwd_graph_output_vts,
bwd_out,
bwd_graph,
bwd_freevars,
bwd_args,
)
# If users call ctx.mark_non_differentiable, we should capture these output tensors who
# are marked as non-differentiable and pass them to ApplyTemplate
# at torch._functorch.autograd_function.AutogradFunctionApply for reconstruction.
non_differentiable_idx = []
if ctx.non_differentiable is not None:
non_differentiable_set = set(ctx.non_differentiable)
assert isinstance(fwd_out, variables.BaseListVariable)
for i, x in enumerate(fwd_out.items):
if (
isinstance(x, variables.TensorVariable)
and x.as_proxy() in non_differentiable_set
):
non_differentiable_idx.append(i)
# Store fwd_body
fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
fwd_name = tx.output.install_subgraph(
"fwd_body",
torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph),
)
fwd_node = make_attr(tx, fwd_name)
# Store bwd_body
bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
bwd_name = tx.output.install_subgraph(
"bwd_body",
torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph),
)
bwd_node = make_attr(tx, bwd_name)
p_args = (
fwd_node,
bwd_node,
*list(fwd_freevars.keys()),
)
kwargs = {
"non_differentiable_idx": non_differentiable_idx,
}
# Store the invocation as a call
from torch._functorch.autograd_function import autograd_function_apply
# We use speculate_subgraph to get the fwd graph, but it's always under no grad mode like what eager mode does.
# The fwd outputs (tensor's example_value) need to be inferred from fake tensor prop to get the correct attributes
# (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing.
# Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it.
with enable_python_dispatcher():
with tx.output.fake_mode:
fwd_freevars_args = [_get_fake_value(arg) for arg in fwd_freevars]
fake_args = (
tx.output.nn_modules[fwd_node.node.name],
tx.output.nn_modules[bwd_node.node.name],
*fwd_freevars_args,
)
example_value = autograd_function_apply(*fake_args, **kwargs)
flat_variable = add_call_function(
tx, autograd_function_apply, p_args, kwargs, example_value
)
overwrite_tensor_vt_proxy(fwd_graph_output_vts, flat_variable)
overwrite_tensor_vt_requires_grad(fwd_graph_output_vts, flat_variable)
return fwd_out
def prepare_ctx_vt(self, tx, args, kwargs):
from . import AutogradFunctionContextVariable
ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
with discard_graph_changes(tx):
# A little hacky, but we need a dummy ctx proxy for speculate_subgraph.
@ -3742,37 +3881,39 @@ class AutogradFunctionApplyVariable(VariableTracker):
)
set_example_value(proxy.node, ctx.value)
ctx.proxy = proxy
return ctx
if isinstance(self.fwd_graph, types.FunctionType):
fwd_fn = UserFunctionVariable(self.fwd_graph)
fwd_args = [ctx, *args]
elif isinstance(self.fwd_graph, types.MethodType):
fwd_fn = UserMethodVariable(
self.fwd_graph.__func__,
UserDefinedClassVariable(self.fwd_graph.__class__),
)
fwd_args = [fwd_fn.obj, ctx, *args]
else:
unimplemented(
gb_type="autograd.Function.apply: non-function or method forward",
context=str(self.fwd_graph),
explanation="Expected forward function to be a function or method.",
hints=[],
)
def trace_forward_graph(self, tx, ctx, fwd_tracer, args, kwargs):
from torch._functorch.autograd_function import autograd_function_trace_helper
fwd_fn, fwd_args = self.prepare_fn_vt(ctx, "forward", args)
# autograd.Function forward does a few things like running in no_grad
# mode and also appying view_as for input tensors that are returned as
# outputs. Therefore, we wrap the original forward in a helper that have
# those extra bits for Dynamo to trace.
fwd_fn = _make_inlined(tx, autograd_function_trace_helper)(fwd_fn)
# Speculate subgraph on the fwd
(fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
tx,
fwd_fn,
fwd_args,
kwargs,
"autograd.Function",
enable_grad=False,
set_subgraph_inputs="semi_automatic",
restore_side_effects=False,
tracer=fwd_tracer,
fwd_out, fwd_graph, fwd_freevars, fwd_graph_output_vts = (
speculate_subgraph_with_auto_output_flattening(
tx,
fwd_fn,
fwd_args,
kwargs,
"autograd.Function",
enable_grad=None,
set_subgraph_inputs="automatic",
restore_side_effects=False,
tracer=fwd_tracer,
)
)
# For inputs that are not used but need to be captured, so that the gradient is tracked.
for arg in args:
if isinstance(arg, variables.TensorVariable):
fwd_tracer.maybe_lift_tracked_freevar_to_input(arg.as_proxy())
if ctx in tx.output.side_effects.store_attr_mutations:
if (
"_materialize_non_diff_grads"
@ -3786,39 +3927,38 @@ class AutogradFunctionApplyVariable(VariableTracker):
*graph_break_hints.SUPPORTABLE,
],
)
return fwd_fn, fwd_out, fwd_graph, fwd_freevars, fwd_graph_output_vts
def trace_backward_graph(self, tx, ctx, fwd_tracer, fwd_out, fwd_fn):
from . import UserDefinedClassVariable, UserFunctionVariable, UserMethodVariable
# Note that for the forward, we do not restore side effects, because we
# want the later tracing to see the side-effects. But for backward, we
# are just trying to capture the graph, and therefore we must restore
# the side effects.
prev_side_effects = tx.output.side_effects
# Speculate subgraph on the backward. We make the bwd tracer a child of
# the fwd tracer, because backward may rely on tensors/attrs created in
# the fwd tracer.
bwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
tx.output,
parent=fwd_tracer,
source_target="autograd.Function",
)
# Speculate subgraph on the backward. We make the
# bwd tracer a child of the fwd tracer, because backward may rely on
# tensors/attrs created in the fwd tracer.
if isinstance(fwd_out, variables.BaseListVariable):
bwd_args = [ctx, *fwd_out.items]
bwd_args = []
if isinstance(fwd_out, variables.TensorVariable):
bwd_args.append(fwd_out)
else:
bwd_args = [ctx, fwd_out]
assert isinstance(fwd_out, variables.BaseListVariable)
for i in fwd_out.items:
if isinstance(i, variables.TensorVariable):
bwd_args.append(i)
else:
bwd_args.append(ConstantVariable.create(None))
bwd_src = AttrSource(self.parent_source, member="backward")
if isinstance(self.bwd_graph, types.FunctionType):
bwd_fn = UserFunctionVariable(self.bwd_graph, source=bwd_src)
elif isinstance(self.bwd_graph, types.MethodType):
bwd_fn = UserMethodVariable(
self.bwd_graph.__func__,
UserDefinedClassVariable(self.bwd_graph.__class__),
source=bwd_src,
)
bwd_args = [bwd_fn.obj, *bwd_args]
else:
unimplemented(
gb_type="autograd.Function.apply: non-function or method backward",
context=str(self.bwd_graph),
explanation="Expected backward function to be a function or method.",
hints=[],
)
bwd_fn, bwd_args = self.prepare_fn_vt(ctx, "backward", bwd_args)
def is_strict_for(v: VariableTracker):
if isinstance(v, variables.TensorVariable):
@ -3826,21 +3966,27 @@ class AutogradFunctionApplyVariable(VariableTracker):
return v.proxy.tracer is not fwd_tracer
return True
# automatic with new placeholder relies on the function arg names to
# create a new proxy. Also, it will always INSERT a tensor placeholder
# as input, even though it might not be used in the graph. This allows
# us to make a mapping for the backward graph.
with (
tx.output.subtracer(fwd_fn, fwd_tracer),
tx.strict_translation_mode(is_strict_for),
):
try:
(bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph(
tx,
bwd_fn,
bwd_args,
kwargs,
"autograd.Function",
enable_grad=False,
set_subgraph_inputs="manual",
restore_side_effects=False,
tracer=bwd_tracer,
bwd_out, bwd_graph, bwd_freevars, bwd_graph_output_vts = (
speculate_subgraph_with_auto_output_flattening(
tx,
bwd_fn,
bwd_args,
{},
"autograd.Function",
enable_grad=False,
set_subgraph_inputs="automatic_with_new_placeholder",
restore_side_effects=False,
tracer=bwd_tracer,
)
)
except torch._dynamo.exc.Unsupported as e:
if isinstance(
@ -3857,16 +4003,14 @@ class AutogradFunctionApplyVariable(VariableTracker):
autograd_function_backward_rewritten,
)
if isinstance(self.bwd_graph, types.FunctionType):
if isinstance(self.bwd_fn, types.FunctionType):
bwd_fn = UserFunctionVariable(
autograd_function_backward_rewritten(self.bwd_graph)
autograd_function_backward_rewritten(self.bwd_fn)
)
elif isinstance(self.bwd_graph, types.MethodType):
elif isinstance(self.bwd_fn, types.MethodType):
bwd_fn = UserMethodVariable(
autograd_function_backward_rewritten(
self.bwd_graph.__func__
),
UserDefinedClassVariable(self.bwd_graph.__class__),
autograd_function_backward_rewritten(self.bwd_fn.__func__),
UserDefinedClassVariable(self.bwd_fn.__class__),
)
else:
unimplemented(
@ -3880,190 +4024,215 @@ class AutogradFunctionApplyVariable(VariableTracker):
"torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops",
[],
):
(bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph(
tx,
bwd_fn,
bwd_args,
kwargs,
"autograd.Function",
enable_grad=False,
set_subgraph_inputs="manual",
restore_side_effects=False,
tracer=bwd_tracer,
bwd_out, bwd_graph, bwd_freevars, bwd_graph_output_vts = (
speculate_subgraph_with_auto_output_flattening(
tx,
bwd_fn,
bwd_args,
{},
"autograd.Function",
enable_grad=False,
set_subgraph_inputs="automatic_with_new_placeholder",
restore_side_effects=False,
tracer=bwd_tracer,
)
)
else:
raise e
tx.output.side_effects = prev_side_effects
return bwd_args, bwd_out, bwd_graph, bwd_freevars, bwd_graph_output_vts
# TODO: assert that bwd_graph didn't capture values that were
# not created inside fwd_graph.
def rewire_bwd_graph_inputs(
self,
fwd_out,
fwd_graph,
fwd_freevars,
bwd_out,
bwd_graph,
bwd_freevars,
orig_fwd_args,
):
# Ensure fwd-input and bwd-output consistency - autograd.Function
# requires that the inputs of the forward line up correctly with the
# outputs of the backward. This is the responsibily of the user. Now,
# when Dynamo is creating a new autograd.Function, it is now Dynamo
# responsibility to do this lineup. To do this, we use the original user
# point as the anchor point to provide this mapping.
# TODO(oulgen): Ideally, we would not do a linear search for output
# node but as things currently are there could be nodes after the
# output node
# This is bug prone as if there's code after the output node, then
# graph.output will append the output at the very end
# This might be a behavior difference
# Some more description to understand the following codebase
#
# fwd_freevars/bwd_freevars: A map from outer graph proxy to inner graph
# placeholder proxy. The keys are ALWAYS outer graph proxy, these could
# be inputs in the main graph or also intermediates in the main graph
# that are passed as inputs to the subgraph.
#
# orig_fwd_args - Variable trackers for the forward graph inputs. Since
# these are inputs, the tensor variables here are all OUTER graph proxies.
# If users call ctx.mark_non_differentiable, we should capture these output tensors who
# are marked as non-differentiable and pass them to ApplyTemplate
# at torch._functorch.autograd_function.AutogradFunctionApply for reconstruction.
non_differentiable_idx = []
if ctx.non_differentiable is not None:
non_differentiable_set = set(ctx.non_differentiable)
assert isinstance(fwd_out, variables.BaseListVariable)
for i, x in enumerate(fwd_out.items):
if (
isinstance(x, variables.TensorVariable)
and x.as_proxy() in non_differentiable_set
):
non_differentiable_idx.append(i)
# bwd_outs - Variable trackers for the backward output. Since these are
# output, the variable trackers here point to INNER graph proxies. The
# special case is when an input is passed directly to the output of the
# backward graph, in which case, the variable tracker can still point to
# the outer graph proxy.
# Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd)
for node in fwd_graph.find_nodes(op="output"):
fwd_graph.erase_node(node)
break
# To make the fwd-inputs and bwd-outputs consistent, we just rewire the
# backward graph outputs to match with the forward graph inputs. To do
# this, we first rely on orig_fwd_args and bwd_outs to make a mapping of
# outer_proxy to inner graph proxy. And walk through the fwd_graph
# inputs and this map to find the bwd outputs.
# Because we lift the bwd_freevars as inputs of the bwd_graph,
# we have to manually add the bwd_freevars as output of fwd_graph.
# However, the bwd_freevars got from speculate_subgraph use the Proxies in the bwd_graph,
# we need to convert them to Proxies in the fwd_graph and then generate new fwd_graph output.
fwd_proxy_of_bwd_freevars = []
for k in bwd_freevars:
if k in fwd_freevars:
fwd_proxy_of_bwd_freevars.append(fwd_freevars[k])
else:
fwd_proxy_of_bwd_freevars.append(k)
def get_bwd_node(vt):
# Backward tensor vt here can be - (1) an intermediate, or (2) input
# to the backward graph. If it is an input to the backward graph, we have to lookup bwd_freevars to get the inner proxy.
return bwd_freevars.get(vt.proxy, vt.proxy).node
def unwrap_proxy(x):
if isinstance(x, torch.fx.Proxy):
return x.node
else:
assert variables.ConstantVariable.is_literal(x), (
f"Only constant is allowed. Got {x}"
)
return x
# Find the mapping between orig_fwd_args and bwd_out
outer_fwd_proxy_to_bwd_node = {}
if isinstance(bwd_out, variables.BaseListVariable):
bwd_outs = bwd_out.items
for idx, fwd_arg in enumerate(orig_fwd_args):
# We care about tensor args. For non-tensor args, the bwd output returns None.
if isinstance(fwd_arg, variables.TensorVariable):
bwd_out_at_idx = bwd_outs[idx]
if isinstance(bwd_out_at_idx, variables.TensorVariable):
outer_fwd_proxy_to_bwd_node[fwd_arg.proxy] = get_bwd_node(
bwd_out_at_idx
)
else:
# backward can return None at the output
assert (
isinstance(bwd_out_at_idx, variables.ConstantVariable)
and bwd_out_at_idx.value is None
)
outer_fwd_proxy_to_bwd_node[fwd_arg.proxy] = None
new_fwd_graph_outputs = (fwd_out.as_proxy(), fwd_proxy_of_bwd_freevars)
new_fwd_graph_outputs = pytree.tree_map(unwrap_proxy, new_fwd_graph_outputs)
fwd_graph.output(new_fwd_graph_outputs)
fwd_graph.lint()
elif isinstance(bwd_out, variables.TensorVariable):
outer_fwd_proxy_to_bwd_node[orig_fwd_args[0].proxy] = get_bwd_node(bwd_out)
# Store fwd_body
fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
fwd_name = tx.output.install_subgraph(
"fwd_body",
torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph),
)
# Ideally, we should have walked through the fwd placeholders. But we
# can instead walk through the fwd_freevars, which is a insertion sorted
# dictionary and therefore represents the outer_proxies for the
# placeholder in the same order as that as placeholders.
rewired_bwd_outputs = [
outer_fwd_proxy_to_bwd_node.get(fwd_proxy) for fwd_proxy in fwd_freevars
]
fwd_node = make_attr(tx, fwd_name)
# The type of original args can be arbitrary, but we only support basic type in FX graph.
# So the speculated subgraph input includes original tensor args and the lifted freevars.
# We need to filter out the original tensor args and concat them with the lifted freevars
# to generate the proxy args for the FX call_function node.
filtered_args = []
# A boolean list to mark if the type of corresponding argument is tensor.
# This is used to determine if a FX node's argument should be an argument of
# ApplyTemplate.forward and if we should skip the output from ApplyTemplate.backward
# at torch._functorch.autograd_function.AutogradFunctionApply.
args_tensor_mask = [False] * len(args)
for i, arg in enumerate(args):
if isinstance(arg, (variables.TensorVariable, variables.SymNodeVariable)):
filtered_args.append(arg)
args_tensor_mask[i] = True
# Rewrite the output of bwd_graph to remove the grad output for the non-Tensor args.
new_bwd_graph_outputs = None
for node in bwd_graph.find_nodes(op="output"):
bwd_graph.erase_node(node)
break
# The same as the above fwd proxies, we need to use the bwd proxies in the bwd_graph
# if some of the output is from fwd_freevars.
bwd_out_proxy = bwd_out.as_proxy()
bwd_proxy_of_fwd_freevars = []
if isinstance(bwd_out_proxy, (tuple, list)):
for k in bwd_out_proxy:
if k in bwd_freevars:
bwd_proxy_of_fwd_freevars.append(bwd_freevars[k])
else:
bwd_proxy_of_fwd_freevars.append(k)
else:
if bwd_out_proxy in bwd_freevars:
bwd_proxy_of_fwd_freevars = bwd_freevars[bwd_out_proxy]
else:
bwd_proxy_of_fwd_freevars = bwd_out_proxy
# Remove bwd output for non-Tensor args.
output_proxy = bwd_proxy_of_fwd_freevars
if isinstance(output_proxy, (tuple, list)):
new_bwd_graph_outputs = ()
for x, mask in zip(output_proxy, args_tensor_mask):
if mask:
new_bwd_graph_outputs = new_bwd_graph_outputs + (x,)
else:
assert x is None, f"Grad of non-Tensor arg {x} is not None."
else:
new_bwd_graph_outputs = output_proxy
# Update the bwd graph output.
new_bwd_graph_outputs = pytree.tree_map(
lambda x: None if x is None else x.node, new_bwd_graph_outputs
)
bwd_graph.output(new_bwd_graph_outputs)
bwd_graph.output(tuple(rewired_bwd_outputs))
bwd_graph.lint()
# Store bwd_body
bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
bwd_name = tx.output.install_subgraph(
"bwd_body",
torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph),
)
def handle_saved_tensors_wiring(
self,
ctx,
fwd_out,
fwd_graph,
fwd_freevars,
fwd_graph_body_outputs,
bwd_out,
bwd_graph,
bwd_freevars,
bwd_args,
):
# First we need to map the existing forward graph outputs to bwd graph inputs.
bwd_input_nodes = list(bwd_graph.find_nodes(op="placeholder"))
bwd_node = make_attr(tx, bwd_name)
fwd_vt_to_bwd_node = {}
bwd_idx = 0
if isinstance(fwd_out, variables.BaseListVariable):
for fwd_vt in fwd_out.items:
if isinstance(fwd_vt, variables.TensorVariable):
fwd_vt_to_bwd_node[fwd_vt] = bwd_input_nodes[bwd_idx]
bwd_idx += 1
else:
if isinstance(fwd_out, variables.TensorVariable):
fwd_vt_to_bwd_node[fwd_out] = bwd_input_nodes[bwd_idx]
bwd_idx += 1
tx.output.side_effects = prev_side_effects
rewired_bwd_graph_inputs = []
for fwd_graph_vt in fwd_graph_body_outputs:
rewired_bwd_graph_inputs.append(fwd_vt_to_bwd_node.get(fwd_graph_vt))
p_args = (
fwd_node,
bwd_node,
*([arg.as_proxy() for arg in filtered_args] + list(fwd_freevars.keys())),
)
kwargs = {
"args_tensor_mask": args_tensor_mask,
"non_differentiable_idx": non_differentiable_idx,
}
# bwd_freevars also contains the symint passed from forward to backward
extra_fwd_output_nodes = []
for fwd_proxy, bwd_inner_proxy in bwd_freevars.items():
# For backward, its easy, just get the node from bwd_inner_proxy
rewired_bwd_graph_inputs.append(bwd_inner_proxy.node)
# Store the invocation as a call
from torch._functorch.autograd_function import autograd_function_apply
# For the fwd_proxy, it could be a proxy from the outer graph, or it
# could be an intermediate.
# First ensure that's its inner fwd proxy
inner_fwd_proxy = fwd_freevars.get(fwd_proxy, fwd_proxy)
# We use speculate_subgraph to get the fwd graph, but it's always under no grad mode like what eager mode does.
# The fwd outputs (tensor's example_value) need to be inferred from fake tensor prop to get the correct attributes
# (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing.
# Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it.
with enable_python_dispatcher(), tx.output.fake_mode:
fake_args = (
tx.output.nn_modules[fwd_node.node.name],
tx.output.nn_modules[bwd_node.node.name],
*(
[
_get_fake_value(arg)
for arg in filtered_args + list(fwd_freevars.keys())
]
),
extra_fwd_output_nodes.append(inner_fwd_proxy.node)
# We have all the info, lets change the fwd graph
fwd_output_nodes = []
for node in fwd_graph.find_nodes(op="output"):
fwd_output_nodes = node.args[0]
fwd_graph.erase_node(node)
break
new_fwd_graph_outputs = (fwd_output_nodes, tuple(extra_fwd_output_nodes))
fwd_graph.output(new_fwd_graph_outputs)
fwd_graph.lint()
# Now lets change the bwd graph.
new_graph = torch.fx.Graph()
env = {}
count = itertools.count()
for node in rewired_bwd_graph_inputs:
if node is None:
new_node = new_graph.placeholder(f"unused_{next(count)}")
else:
new_node = new_graph.placeholder(node.name)
new_node.meta = copy.copy(node.meta)
env[node] = new_node
for node in bwd_graph.nodes:
if node.op == "placeholder":
assert node in env
else:
env[node] = new_graph.node_copy(node, lambda x: env[x])
env[node].meta = copy.copy(node.meta)
new_graph.lint()
return fwd_graph, new_graph
def prepare_fn_vt(self, ctx, method_name, args):
from . import UserDefinedClassVariable, UserFunctionVariable, UserMethodVariable
source = None
if self.parent_source:
source = AttrSource(self.parent_source, member=method_name)
if method_name == "forward":
fn = self.fwd_fn
else:
fn = self.bwd_fn
if isinstance(fn, types.FunctionType):
fn_vt = UserFunctionVariable(fn, source=source)
fn_args = [ctx, *args]
elif isinstance(fn, types.MethodType):
cls_vt = UserDefinedClassVariable(fn.__class__)
fn_vt = UserMethodVariable(
fn.__func__,
cls_vt,
source=source,
)
example_value = autograd_function_apply(*fake_args, **kwargs)
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
autograd_function_apply,
args=p_args,
kwargs=kwargs,
),
example_value=example_value,
)
fn_args = [cls_vt, ctx, *args]
else:
unimplemented(
gb_type="autograd.Function.apply: non-function or method forward",
context=str(fn),
explanation=f"Expected {method_name} to be a function or method.",
hints=[],
)
return fn_vt, fn_args
def _get_fake_value(x):

View File

@ -701,6 +701,7 @@ class AutogradFunctionVariable(VariableTracker):
VariableTracker.visit(visit, (args, kwargs))
if requires_grad and torch.is_grad_enabled():
source = self.source
if config.capture_autograd_function is False:
warnings.warn(
"The config.capture_autograd_function flag is deprecated, it's now always true."
@ -720,6 +721,10 @@ class AutogradFunctionVariable(VariableTracker):
forward_fn = autograd_function_forward_rewritten(
self.fn_cls.forward, self.fn_cls.setup_context
)
# The forward points to a new function now, so we can't use the
# old source. Later on, we guard specifically on
# is_setup_ctx_defined
source = None
vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined]
if vjp_fn is not torch.autograd.Function.vjp:
@ -752,29 +757,23 @@ class AutogradFunctionVariable(VariableTracker):
from .higher_order_ops import AutogradFunctionApplyVariable
source = self.source
if source is None:
if source is None and not is_setup_ctx_defined:
source = AttrSource(
tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
)
apply_source = source and AttrSource(source, member="apply")
val = AutogradFunctionApplyVariable(
forward_fn,
self.fn_cls.backward,
source,
source=AttrSource(source, member="apply"),
source=apply_source,
).call_function(tx, args, kwargs)
# Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
# the forward function, as we don't want to generate guards for new_forward.__closure__
# if forward is rewritten by autograd_function_forward_rewritten.
# But we still need to generate correct guards for the original forward and setup_context
# functions, so we have to add guards manually.
if self.source:
if self.source and is_setup_ctx_defined:
fwd_src = AttrSource(self.source, "forward")
install_guard(fwd_src.make_guard(GuardBuilder.CLOSURE_MATCH))
if is_setup_ctx_defined:
setup_ctx_src = AttrSource(self.source, "setup_context")
install_guard(setup_ctx_src.make_guard(GuardBuilder.CLOSURE_MATCH))
setup_ctx_src = AttrSource(self.source, "setup_context")
install_guard(setup_ctx_src.make_guard(GuardBuilder.CLOSURE_MATCH))
return val

View File

@ -743,20 +743,14 @@ class AutogradFunctionApply(HigherOrderOperator):
def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
saved_values = None
args_tensor_mask = fwd_kwargs["args_tensor_mask"]
non_differentiable_idx = fwd_kwargs["non_differentiable_idx"]
length_of_tensor_args = sum(args_tensor_mask)
# Filter out the original tensor args from fwd_args,
# lifted freevars should not be args of ApplyTemplate.apply
# since we don't need to calculate the gradients of them.
new_fwd_args = fwd_args[:length_of_tensor_args]
class ApplyTemplate(torch.autograd.Function):
@staticmethod
# pyrefly: ignore [bad-override]
def forward(ctx, *args):
nonlocal saved_values
output, saved_values = fwd(None, *fwd_args)
output, saved_values = fwd(*args)
# If users call ctx.mark_non_differentiable() in the original fwd function.
if len(non_differentiable_idx) > 0:
@ -770,9 +764,45 @@ class AutogradFunctionApply(HigherOrderOperator):
@staticmethod
def backward(ctx, *grad):
return bwd(None, *grad, *saved_values)
return bwd(*grad, *saved_values)
return ApplyTemplate.apply(*new_fwd_args)
return ApplyTemplate.apply(*fwd_args)
autograd_function_apply = AutogradFunctionApply()
# autograd.Function forward does more than just running the forward method. Most
# of this logic is in C++. Here, we rewrite that functionality in python and let
# Dynamo trace it. This is most probably incomplete.
def autograd_function_trace_helper(orig_fwd):
def inner(*args, **kwargs):
with torch.no_grad():
outs = orig_fwd(*args, **kwargs)
# Handle the case where if the input is passed on directly to the output, we call view_as
# Refer to https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/custom_function.cpp#L254
tensor_args = {arg for arg in args if isinstance(arg, torch.Tensor)}
if isinstance(outs, torch.Tensor):
if outs in tensor_args:
return outs.view_as(outs)
else:
return outs
new_outs = []
for out in outs:
if isinstance(out, torch.Tensor):
if out in tensor_args:
new_outs.append(out.view_as(out))
else:
new_outs.append(out)
else:
new_outs.append(out)
return tuple(new_outs)
# TODO - there is missing functionality here, where autograd.Function
# overwrites the requires_grad_ of the output tensors depending on the
# `mark_non_differentiable`. Currently, this is handled hackily in
# Dynamo, where we just overwrite the variable trackers requires_grad.
return inner