mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-20 02:24:54 +08:00
Compare commits
72 Commits
update-vll
...
ciflow/ind
| Author | SHA1 | Date | |
|---|---|---|---|
| 74e5c3d921 | |||
| 00db1dfc23 | |||
| 2d586c736d | |||
| 8ab1d524f7 | |||
| 2750a7fee0 | |||
| 2817633c36 | |||
| fb90c47154 | |||
| 3fc9fb48ff | |||
| 2039ee85db | |||
| e3fbec109f | |||
| 0d8b37da5f | |||
| a7db673e53 | |||
| 1d1f83dd92 | |||
| a21b825f19 | |||
| e507e07361 | |||
| 6531323366 | |||
| 373042f6f0 | |||
| 390e9364af | |||
| 68622b9c6e | |||
| 33425d4a8e | |||
| 4784498d05 | |||
| 4be08d8c28 | |||
| b132cc8a68 | |||
| 2a36b972d4 | |||
| c9f22e75ae | |||
| fe1afe16e3 | |||
| 265a34a5eb | |||
| 6486b54865 | |||
| 0690194dd4 | |||
| 87b47990a8 | |||
| 395f9c3ca1 | |||
| c332695a95 | |||
| 81ea254188 | |||
| bf03db27c8 | |||
| 631b12e8ba | |||
| 18ef4c3128 | |||
| 4b2c463fd0 | |||
| 080c5d1331 | |||
| e81aaecc29 | |||
| a4b05b0738 | |||
| 53a946f7f9 | |||
| f5d3f8901c | |||
| 7a832991fd | |||
| dbf0853ecf | |||
| a22b8ececc | |||
| 29af216024 | |||
| 7b0c475bc4 | |||
| b552a4eba1 | |||
| b3e120665b | |||
| 96a8d1c5e0 | |||
| 39307c3db2 | |||
| 3d6061d56a | |||
| dc55769bb6 | |||
| 2c74beddf6 | |||
| 12ff17857e | |||
| 4ae3c59ce2 | |||
| 7a8ad5f874 | |||
| dd09fa089d | |||
| 0995593caa | |||
| 69a4358a01 | |||
| 0ab9e050ab | |||
| 651e9dbf94 | |||
| 56bd4c695a | |||
| 1cb7be9419 | |||
| 00f68803d3 | |||
| de1f732075 | |||
| cbfee32779 | |||
| 0e38867920 | |||
| cefd269c35 | |||
| 7fcf3a1488 | |||
| 9a88bd06e1 | |||
| ccc9750df1 |
@ -8,6 +8,7 @@ import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
import torch._dynamo.utils
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs
|
||||
from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu
|
||||
|
||||
|
||||
@ -604,22 +605,23 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
fwd_body_0 = self.fwd_body_0
|
||||
bwd_body_0 = self.bwd_body_0
|
||||
autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
|
||||
return (autograd_function_apply,)
|
||||
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_weird_b, l_weird_c, l_x_, l_z_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_weird_b = l_weird_c = l_x_ = l_z_ = None
|
||||
getitem: "f32[]" = autograd_function_apply[0]; autograd_function_apply = None
|
||||
return (getitem,)
|
||||
|
||||
class fwd_body_0(torch.nn.Module):
|
||||
def forward(self, ctx : torch.autograd.function.Function, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
|
||||
def forward(self, l_weird_b: "f32[]", l_weird_c: "f32[]", l_x_: "f32[]", l_z_: "f32[]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
mul: "f32[]" = l_weird_b * l_weird_c
|
||||
clone: "f32[]" = x.clone(); x = None
|
||||
mul_1: "f32[]" = mul * clone; mul = clone = None
|
||||
clone: "f32[]" = l_x_.clone(); l_x_ = None
|
||||
outs: "f32[]" = mul * clone; mul = clone = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return (mul_1, [l_weird_b, l_weird_c])
|
||||
return ((outs,), (l_weird_b, l_weird_c))
|
||||
|
||||
class bwd_body_0(torch.nn.Module):
|
||||
def forward(self, ctx : torch.autograd.function.Function, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
|
||||
def forward(self, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
mul: "f32[]" = grad * l_weird_b; l_weird_b = None
|
||||
@ -627,7 +629,7 @@ class GraphModule(torch.nn.Module):
|
||||
mul_2: "f32[]" = grad * 2; grad = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return (mul_1, mul_2)
|
||||
return (None, None, mul_1, mul_2)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -1125,32 +1127,33 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
fwd_body_0 = self.fwd_body_0
|
||||
bwd_body_0 = self.bwd_body_0
|
||||
autograd_function_apply: "f32[5, 4]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_weight_, args_tensor_mask = [True, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_weight_ = None
|
||||
return (autograd_function_apply,)
|
||||
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_weight_, l_x_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_weight_ = l_x_ = None
|
||||
getitem: "f32[5, 4]" = autograd_function_apply[0]; autograd_function_apply = None
|
||||
return (getitem,)
|
||||
|
||||
class fwd_body_0(torch.nn.Module):
|
||||
def forward(self, ctx : torch.autograd.function.Function, x: "f32[5, 3]", weight: "f32[4, 3]"):
|
||||
def forward(self, l_weight_: "f32[4, 3]", l_x_: "f32[5, 3]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
t: "f32[3, 4]" = weight.t()
|
||||
y: "f32[5, 4]" = x.matmul(t); t = None
|
||||
t: "f32[3, 4]" = l_weight_.t()
|
||||
y: "f32[5, 4]" = l_x_.matmul(t); t = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return (y, [weight, x])
|
||||
return ((y,), (l_weight_, l_x_))
|
||||
|
||||
class bwd_body_0(torch.nn.Module):
|
||||
def forward(self, function_ctx : torch.autograd.function.Function, y: "f32[5, 4]", weight: "f32[4, 3]", x: "f32[5, 3]"):
|
||||
def forward(self, y: "f32[5, 4]", l_weight_: "f32[4, 3]", l_x_: "f32[5, 3]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
contiguous: "f32[5, 4]" = y.contiguous(); y = None
|
||||
|
||||
grad_x: "f32[5, 3]" = contiguous.matmul(weight); weight = None
|
||||
grad_x: "f32[5, 3]" = contiguous.matmul(l_weight_); l_weight_ = None
|
||||
|
||||
transpose: "f32[4, 5]" = contiguous.transpose(0, 1); contiguous = None
|
||||
grad_weight: "f32[4, 3]" = transpose.matmul(x); transpose = x = None
|
||||
grad_weight: "f32[4, 3]" = transpose.matmul(l_x_); transpose = l_x_ = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return (grad_x, grad_weight)
|
||||
return (grad_weight, grad_x)
|
||||
""",
|
||||
)
|
||||
|
||||
@ -1259,7 +1262,7 @@ class GraphModule(torch.nn.Module):
|
||||
def foo(x):
|
||||
return Foo.apply(x)
|
||||
|
||||
foo(torch.randn(2, requires_grad=True))
|
||||
foo(torch.randn(2, requires_grad=True)).sum().backward()
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
def test_mark_non_differentiable(self):
|
||||
@ -1309,24 +1312,24 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
fwd_body_0 = self.fwd_body_0
|
||||
bwd_body_0 = self.bwd_body_0
|
||||
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
|
||||
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
|
||||
getitem: "f32[]" = autograd_function_apply[0]
|
||||
getitem_1: "f32[]" = autograd_function_apply[1]; autograd_function_apply = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class fwd_body_0(torch.nn.Module):
|
||||
def forward(self, ctx : torch.autograd.function.Function, x: "f32[]", y: "f32[]"):
|
||||
def forward(self, l_x_: "f32[]", l_y_: "f32[]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
out1: "f32[]" = x.sin(); x = None
|
||||
out1: "f32[]" = l_x_.sin(); l_x_ = None
|
||||
|
||||
out2: "f32[]" = y * 2; y = None
|
||||
out2: "f32[]" = l_y_ * 2; l_y_ = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return ((out1, out2), [])
|
||||
return ((out1, out2), ())
|
||||
|
||||
class bwd_body_0(torch.nn.Module):
|
||||
def forward(self, ctx : torch.autograd.function.Function, grad1: "f32[]", grad2: "f32[]"):
|
||||
def forward(self, grad1: "f32[]", grad2: "f32[]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
cos: "f32[]" = grad1.cos(); grad1 = None
|
||||
@ -1438,7 +1441,7 @@ class GraphModule(torch.nn.Module):
|
||||
# would be generated by autograd engine.
|
||||
return result * 0.5
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
@torch.compile(backend="aot_eager", fullgraph=True)
|
||||
def fn(x):
|
||||
x, _ = MyCube.apply(x)
|
||||
x, _ = MyCube.apply(x)
|
||||
@ -1475,7 +1478,7 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(out, x + 1)
|
||||
self.assertEqual(x.grad.shape, shape)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, 1)
|
||||
self.assertEqual(cnt.op_count, 2)
|
||||
|
||||
@requires_gpu
|
||||
def test_triton_kernel_basic(self):
|
||||
@ -1543,6 +1546,238 @@ class GraphModule(torch.nn.Module):
|
||||
loss.backward()
|
||||
self.assertEqual(x + y, z)
|
||||
|
||||
def test_rewired_bwd_output(self):
|
||||
class Add(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
a = torch.sin(x)
|
||||
b = torch.cos(y)
|
||||
result = a * b
|
||||
# Save input, output and intermediate to test all cases
|
||||
ctx.save_for_backward(a, x, result)
|
||||
return result, a + b
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_a, grad_b):
|
||||
(a, x, result) = ctx.saved_tensors
|
||||
return a * grad_b * 2 + x, result + grad_a * 3
|
||||
|
||||
def fn(x, y):
|
||||
z = Add.apply(torch.cos(x), torch.cos(y))
|
||||
return z[0] + z[1]
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, fullgraph=True, backend=backend)
|
||||
x = torch.randn(8, 8, requires_grad=True)
|
||||
y = torch.randn(8, 8, requires_grad=True)
|
||||
x_clone = x.detach().clone().requires_grad_(True)
|
||||
y_clone = y.detach().clone().requires_grad_(True)
|
||||
torch._dynamo.mark_dynamic(x_clone, 0)
|
||||
torch._dynamo.mark_dynamic(y_clone, 0)
|
||||
|
||||
ref = fn(x, y)
|
||||
res = opt_fn(x_clone, y_clone)
|
||||
|
||||
ref.sum().backward()
|
||||
res.sum().backward()
|
||||
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(x.grad, x_clone.grad)
|
||||
|
||||
self.assertExpectedInline(
|
||||
torch._dynamo.testing.normalize_gm(
|
||||
backend.graphs[0].print_readable(print_output=False)
|
||||
),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s17)", L_x_: "f32[s17, 8]", s17: "Sym(s17)", L_y_: "f32[s17, 8]"):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
arg: "f32[s17, 8]" = torch.cos(l_x_); l_x_ = None
|
||||
arg_1: "f32[s17, 8]" = torch.cos(l_y_); l_y_ = None
|
||||
fwd_body_0 = self.fwd_body_0
|
||||
bwd_body_0 = self.bwd_body_0
|
||||
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, s77, arg, s17, arg_1, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = s77 = arg = s17 = arg_1 = None
|
||||
getitem: "f32[s17, 8]" = autograd_function_apply[0]
|
||||
getitem_1: "f32[s17, 8]" = autograd_function_apply[1]; autograd_function_apply = None
|
||||
|
||||
add: "f32[s17, 8]" = getitem + getitem_1; getitem = getitem_1 = None
|
||||
return (add,)
|
||||
|
||||
class fwd_body_0(torch.nn.Module):
|
||||
def forward(self, s77: "Sym(s17)", cos: "f32[s17, 8]", s17: "Sym(s17)", cos_1: "f32[s17, 8]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
a: "f32[s17, 8]" = torch.sin(cos)
|
||||
|
||||
b: "f32[s17, 8]" = torch.cos(cos_1); cos_1 = None
|
||||
|
||||
result: "f32[s17, 8]" = a * b
|
||||
|
||||
out: "f32[s17, 8]" = a + b; b = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return ((result, out), (s17, a, cos, result))
|
||||
|
||||
class bwd_body_0(torch.nn.Module):
|
||||
def forward(self, grad_a: "f32[s17, 8]", grad_b: "f32[s17, 8]", s17: "Sym(s17)", a: "f32[s17, 8]", arg: "f32[s17, 8]", result: "f32[s17, 8]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
mul: "f32[s17, 8]" = a * grad_b; a = grad_b = None
|
||||
mul_1: "f32[s17, 8]" = mul * 2; mul = None
|
||||
add: "f32[s17, 8]" = mul_1 + arg; mul_1 = arg = None
|
||||
mul_2: "f32[s17, 8]" = grad_a * 3; grad_a = None
|
||||
add_1: "f32[s17, 8]" = result + mul_2; result = mul_2 = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return (None, add, None, add_1)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_udf_output(self):
|
||||
class Foo:
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
class Add(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
a = torch.sin(x)
|
||||
b = torch.cos(y)
|
||||
ctx.save_for_backward(a)
|
||||
return Foo(a, b), x * y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_a, grad_b):
|
||||
(a,) = ctx.saved_tensors
|
||||
return grad_b * 2, grad_b * 3
|
||||
|
||||
def fn(x, y):
|
||||
z = Add.apply(x, y)
|
||||
return z[0].a + z[0].b + z[1]
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, fullgraph=True, backend=backend)
|
||||
x = torch.randn(8, 8, requires_grad=True)
|
||||
y = torch.randn(8, 8, requires_grad=True)
|
||||
x_clone = x.detach().clone().requires_grad_(True)
|
||||
y_clone = y.detach().clone().requires_grad_(True)
|
||||
|
||||
ref = fn(x, y)
|
||||
res = opt_fn(x_clone, y_clone)
|
||||
|
||||
ref.sum().backward()
|
||||
res.sum().backward()
|
||||
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(x.grad, x_clone.grad)
|
||||
|
||||
self.assertExpectedInline(
|
||||
torch._dynamo.testing.normalize_gm(
|
||||
backend.graphs[0].print_readable(print_output=False)
|
||||
),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[8, 8]", L_y_: "f32[8, 8]"):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
fwd_body_0 = self.fwd_body_0
|
||||
bwd_body_0 = self.bwd_body_0
|
||||
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
|
||||
getitem: "f32[8, 8]" = autograd_function_apply[0]
|
||||
getitem_1: "f32[8, 8]" = autograd_function_apply[1]
|
||||
getitem_2: "f32[8, 8]" = autograd_function_apply[2]; autograd_function_apply = None
|
||||
|
||||
add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None
|
||||
add_1: "f32[8, 8]" = add + getitem_2; add = getitem_2 = None
|
||||
return (add_1,)
|
||||
|
||||
class fwd_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[8, 8]", l_y_: "f32[8, 8]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
a: "f32[8, 8]" = torch.sin(l_x_)
|
||||
|
||||
b: "f32[8, 8]" = torch.cos(l_y_)
|
||||
|
||||
out: "f32[8, 8]" = l_x_ * l_y_; l_x_ = l_y_ = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return ((a, b, out), ())
|
||||
|
||||
class bwd_body_0(torch.nn.Module):
|
||||
def forward(self, unused_0, unused_1, grad_b: "f32[8, 8]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
mul: "f32[8, 8]" = grad_b * 2
|
||||
mul_1: "f32[8, 8]" = grad_b * 3; grad_b = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return (mul, mul_1)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_aliasing_output(self):
|
||||
class Add(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return grad_out
|
||||
|
||||
def fn(x):
|
||||
y = Add.apply(x)
|
||||
if y is x:
|
||||
return torch.cos(y)
|
||||
return torch.sin(y)
|
||||
|
||||
x = torch.randn(8, 8, requires_grad=True)
|
||||
|
||||
ref = fn(x)
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, fullgraph=True, backend=backend)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
# Must have `view_as`
|
||||
self.assertExpectedInline(
|
||||
torch._dynamo.testing.normalize_gm(
|
||||
backend.graphs[0].print_readable(print_output=False)
|
||||
),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[8, 8]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
fwd_body_0 = self.fwd_body_0
|
||||
bwd_body_0 = self.bwd_body_0
|
||||
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = None
|
||||
y: "f32[8, 8]" = autograd_function_apply[0]; autograd_function_apply = None
|
||||
|
||||
sin: "f32[8, 8]" = torch.sin(y); y = None
|
||||
return (sin,)
|
||||
|
||||
class fwd_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[8, 8]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
|
||||
view_as: "f32[8, 8]" = l_x_.view_as(l_x_); l_x_ = None
|
||||
return ((view_as,), ())
|
||||
|
||||
class bwd_body_0(torch.nn.Module):
|
||||
def forward(self, grad_out: "f32[8, 8]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return (grad_out,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -10033,7 +10033,6 @@ def ___make_guard_fn():
|
||||
# as the same across the two tracings. This is an unlikely situation in real use cases, so we add another
|
||||
# `test_validate_outputs_unbacked_by_custom_op` to mitigate it and keep this one as expected failure
|
||||
# until we have a proper fix.
|
||||
@unittest.expectedFailure
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_validate_outputs_unbacked(self):
|
||||
class SillyCat(torch.autograd.Function):
|
||||
|
||||
@ -3796,7 +3796,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
expected = fn(*inputs1)
|
||||
actual = fn_opt(*inputs2)
|
||||
self.assertTrue(same(actual, expected))
|
||||
self.assertEqual(cnt.op_count, 1)
|
||||
self.assertEqual(cnt.op_count, 2)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
cnt.clear()
|
||||
counters.clear()
|
||||
|
||||
@ -2784,34 +2784,34 @@ def forward(self, add, tangents_1):
|
||||
return (mul_1, None)""",
|
||||
)
|
||||
|
||||
def test_backward_mutation_metadata(self):
|
||||
class BwMutation(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, a, b):
|
||||
ctx.save_for_backward(b)
|
||||
return a.clone(), b.clone()
|
||||
# def test_backward_mutation_metadata(self):
|
||||
# class BwMutation(torch.autograd.Function):
|
||||
# @staticmethod
|
||||
# def forward(ctx, a, b):
|
||||
# ctx.save_for_backward(b)
|
||||
# return a.clone(), b.clone()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_a, grad_b):
|
||||
(b,) = ctx.saved_tensors
|
||||
# bw metadata mutation
|
||||
b.transpose_(1, 0)
|
||||
return grad_a.clone(), grad_b.clone()
|
||||
# @staticmethod
|
||||
# def backward(ctx, grad_a, grad_b):
|
||||
# (b,) = ctx.saved_tensors
|
||||
# # bw metadata mutation
|
||||
# b.transpose_(1, 0)
|
||||
# return grad_a.clone(), grad_b.clone()
|
||||
|
||||
def f(a, b):
|
||||
a_, b_ = BwMutation.apply(a, b)
|
||||
out = a_ * b_
|
||||
return out
|
||||
# def f(a, b):
|
||||
# a_, b_ = BwMutation.apply(a, b)
|
||||
# out = a_ * b_
|
||||
# return out
|
||||
|
||||
inp_no_grad = [
|
||||
torch.ones(3, 3, requires_grad=True),
|
||||
torch.ones(3, 3, requires_grad=False),
|
||||
]
|
||||
# inp_no_grad = [
|
||||
# torch.ones(3, 3, requires_grad=True),
|
||||
# torch.ones(3, 3, requires_grad=False),
|
||||
# ]
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, "input that had its metadata mutated in the backward"
|
||||
):
|
||||
self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
|
||||
# with self.assertRaisesRegex(
|
||||
# AssertionError, "input that had its metadata mutated in the backward"
|
||||
# ):
|
||||
# self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)
|
||||
|
||||
def test_backward_mutation_on_grad_out(self):
|
||||
class BwMutation(torch.autograd.Function):
|
||||
@ -8681,7 +8681,7 @@ class MockFXGraphCache:
|
||||
FAILING_CACHE_TESTS = (
|
||||
# BypassAOTAutogradCache: unsupported nodes
|
||||
"test_backward_mutation_data", # Custom Autograd Function
|
||||
"test_backward_mutation_metadata", # Custom Autograd Function
|
||||
# "test_backward_mutation_metadata", # Custom Autograd Function
|
||||
"test_input_output_aliase_custom_autograd_function",
|
||||
)
|
||||
|
||||
|
||||
@ -2048,27 +2048,28 @@ class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[8, 8]"):
|
||||
fwd_body_0 = self.fwd_body_0
|
||||
bwd_body_0 = self.bwd_body_0
|
||||
autograd_function_apply: "f32[8, 8]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, args_tensor_mask = [True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = None
|
||||
return (autograd_function_apply,)
|
||||
autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = None
|
||||
getitem: "f32[8, 8]" = autograd_function_apply[0]; autograd_function_apply = None
|
||||
return (getitem,)
|
||||
|
||||
class fwd_body_0(torch.nn.Module):
|
||||
def forward(self, ctx : torch.autograd.function.Function, x: "f32[8, 8]"):
|
||||
def forward(self, l_x_: "f32[8, 8]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
sin: "f32[8, 8]" = torch.sin(x)
|
||||
outs: "f32[8, 8]" = torch.sin(l_x_)
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return (sin, [x])
|
||||
return ((outs,), (l_x_,))
|
||||
|
||||
class bwd_body_0(torch.nn.Module):
|
||||
def forward(self, ctx : torch.autograd.function.Function, grad_out: "f32[8, 8]", x: "f32[8, 8]"):
|
||||
def forward(self, grad_out: "f32[8, 8]", l_x_: "f32[8, 8]"):
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
|
||||
|
||||
cos: "f32[8, 8]" = torch.cos(grad_out); grad_out = None
|
||||
mul: "f32[8, 8]" = x * cos; x = cos = None
|
||||
mul: "f32[8, 8]" = l_x_ * cos; l_x_ = cos = None
|
||||
|
||||
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
|
||||
return mul
|
||||
return (mul,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@ -3515,6 +3515,12 @@
|
||||
}
|
||||
],
|
||||
"GB0346": [
|
||||
{
|
||||
"Gb_type": "autograd.Function.apply: non-function or method forward",
|
||||
"Context": "str(fn)",
|
||||
"Explanation": "Expected {method_name} to be a function or method.",
|
||||
"Hints": []
|
||||
},
|
||||
{
|
||||
"Gb_type": "autograd.Function.apply: non-function or method forward",
|
||||
"Context": "str(self.fwd_graph)",
|
||||
@ -3667,5 +3673,15 @@
|
||||
"Use custom operators instead of direct attribute/method access."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0363": [
|
||||
{
|
||||
"Gb_type": "Unsupported input type in backward graph of autograd.Function",
|
||||
"Context": "Unsupported node type {type(example_value)}",
|
||||
"Explanation": "Node {node} has example_value {example_value} which is not a Tensor/Symint/None",
|
||||
"Hints": [
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ their semantic behavior.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
@ -251,6 +252,66 @@ def _make_inlined(tx: "InstructionTranslator", f):
|
||||
return inline_call
|
||||
|
||||
|
||||
def add_call_function(
|
||||
tx: "InstructionTranslator",
|
||||
fn: Any,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
flat_example_value: Any,
|
||||
):
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
# Store the invocation as a call
|
||||
flat_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
fn,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
),
|
||||
example_value=flat_example_value,
|
||||
)
|
||||
return flat_variable
|
||||
|
||||
|
||||
def overwrite_tensor_vt_requires_grad(graph_output_vts, flat_variable):
|
||||
# this is required for faithfully representing the autograd.Function forward
|
||||
# outputs.
|
||||
for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items):
|
||||
if isinstance(orig_vt, (variables.SymNodeVariable, variables.TensorVariable)):
|
||||
assert isinstance(
|
||||
subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable)
|
||||
)
|
||||
orig_vt.requires_grad = subgraph_vt.requires_grad
|
||||
if orig_vt.requires_grad:
|
||||
orig_vt.has_grad_fn = True
|
||||
|
||||
|
||||
def overwrite_tensor_vt_proxy(graph_output_vts, flat_variable):
|
||||
# wrap_fx_proxy creates fresh variable trackers. However, the main program
|
||||
# after the speculate subgraph can still use the original tensor vts that
|
||||
# are still pointing to the nodes present in the subgraph. So, we reproxify
|
||||
# the original tensor vts with the subgraph outputs. This way, whenever the
|
||||
# outer graph uses an original vt, it uses the subgraph output.
|
||||
#
|
||||
# This is critical for maintaining the separation between:
|
||||
# - `body_r`: The output VT structure that Dynamo continues tracing (may
|
||||
# contain non-proxyable objects, nested structures, etc.)
|
||||
# - `graph_output_vts`: Only the tensor/symint VTs that were actual graph
|
||||
# outputs from speculate_subgraph
|
||||
#
|
||||
# By overwriting the proxies of VTs in `body_r` with the proxies from the
|
||||
# HOP call, we ensure the outer graph correctly references the HOP outputs
|
||||
# while still allowing `body_r` to contain arbitrary Python objects.
|
||||
for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items):
|
||||
if isinstance(orig_vt, (variables.SymNodeVariable, variables.TensorVariable)):
|
||||
assert isinstance(
|
||||
subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable)
|
||||
)
|
||||
orig_vt.proxy = subgraph_vt.proxy
|
||||
|
||||
|
||||
def _call_function_with_auto_output_flattening(
|
||||
tx: "InstructionTranslator",
|
||||
fn: Any,
|
||||
@ -282,44 +343,10 @@ def _call_function_with_auto_output_flattening(
|
||||
Returns:
|
||||
The body_r VT (unchanged), which Dynamo will continue tracing with
|
||||
"""
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
# Store the invocation as a call
|
||||
flat_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
fn,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
),
|
||||
example_value=flat_example_value,
|
||||
)
|
||||
|
||||
# wrap_fx_proxy creates fresh variable trackers. However, the main program
|
||||
# after the speculate subgraph can still use the original tensor vts that
|
||||
# are still pointing to the nodes present in the subgraph. So, we reproxify
|
||||
# the original tensor vts with the subgraph outputs. This way, whenever the
|
||||
# outer graph uses an original vt, it uses the subgraph output.
|
||||
#
|
||||
# This is critical for maintaining the separation between:
|
||||
# - `body_r`: The output VT structure that Dynamo continues tracing (may
|
||||
# contain non-proxyable objects, nested structures, etc.)
|
||||
# - `graph_output_vts`: Only the tensor/symint VTs that were actual graph
|
||||
# outputs from speculate_subgraph
|
||||
#
|
||||
# By overwriting the proxies of VTs in `body_r` with the proxies from the
|
||||
# HOP call, we ensure the outer graph correctly references the HOP outputs
|
||||
# while still allowing `body_r` to contain arbitrary Python objects.
|
||||
flat_variable = add_call_function(tx, fn, args, kwargs, flat_example_value)
|
||||
if body_r is not None:
|
||||
for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items):
|
||||
if isinstance(
|
||||
orig_vt, (variables.SymNodeVariable, variables.TensorVariable)
|
||||
):
|
||||
assert isinstance(
|
||||
subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable)
|
||||
)
|
||||
orig_vt.proxy = subgraph_vt.proxy
|
||||
overwrite_tensor_vt_proxy(graph_output_vts, flat_variable)
|
||||
return body_r
|
||||
|
||||
|
||||
@ -823,16 +850,8 @@ def validate_args_and_maybe_create_graph_inputs(
|
||||
if set_subgraph_inputs == "automatic":
|
||||
args.append(a)
|
||||
continue
|
||||
elif set_subgraph_inputs == "semi_automatic":
|
||||
if isinstance(a, AutogradFunctionContextVariable):
|
||||
example_value = a.as_proxy().node.meta["example_value"]
|
||||
arg_name = (
|
||||
a.as_proxy().node.name
|
||||
if sub_args_names is None
|
||||
else sub_args_names[idx]
|
||||
)
|
||||
tracer.create_graph_input(arg_name, a.python_type(), example_value)
|
||||
elif a.maybe_fx_node() is not None:
|
||||
elif set_subgraph_inputs == "automatic_with_new_placeholder":
|
||||
if isinstance(a, variables.TensorVariable):
|
||||
node = a.maybe_fx_node()
|
||||
example_value = node.meta["example_value"]
|
||||
arg_name = (
|
||||
@ -1197,7 +1216,7 @@ def speculate_subgraph_with_auto_output_flattening(
|
||||
enable_grad: Optional[bool] = None,
|
||||
# TODO - We can probably just make everyone use automatic for wrap_semantics
|
||||
set_subgraph_inputs: Literal[
|
||||
"automatic", "semi_automatic", "flatten_manual", "manual"
|
||||
"automatic", "automatic_with_new_placeholder", "flatten_manual", "manual"
|
||||
] = "automatic",
|
||||
# Make default False
|
||||
restore_side_effects: bool = True,
|
||||
@ -1288,7 +1307,7 @@ def speculate_subgraph_with_auto_output_flattening(
|
||||
|
||||
assert set_subgraph_inputs in {
|
||||
"automatic",
|
||||
"semi_automatic",
|
||||
"automatic_with_new_placeholder",
|
||||
"flatten_manual",
|
||||
"manual",
|
||||
}, "Please use one of the supported set_subgraph_inputs options."
|
||||
@ -1514,7 +1533,7 @@ def speculate_subgraph(
|
||||
|
||||
assert set_subgraph_inputs in {
|
||||
"automatic",
|
||||
"semi_automatic",
|
||||
"automatic_with_new_placeholder",
|
||||
"flatten_manual",
|
||||
"manual",
|
||||
}, "Please use one of the supported set_subgraph_inputs options."
|
||||
@ -3675,10 +3694,10 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||
|
||||
|
||||
class AutogradFunctionApplyVariable(VariableTracker):
|
||||
def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs) -> None:
|
||||
def __init__(self, fwd_fn, bwd_fn, parent_source, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.fwd_graph = fwd_graph
|
||||
self.bwd_graph = bwd_graph
|
||||
self.fwd_fn = fwd_fn
|
||||
self.bwd_fn = bwd_fn
|
||||
self.parent_source = parent_source
|
||||
|
||||
def call_function(
|
||||
@ -3687,52 +3706,172 @@ class AutogradFunctionApplyVariable(VariableTracker):
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from . import (
|
||||
AutogradFunctionContextVariable,
|
||||
UserDefinedClassVariable,
|
||||
UserFunctionVariable,
|
||||
UserMethodVariable,
|
||||
)
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
"""
|
||||
Consider the following:
|
||||
At the highest level, the goal of tracing an autograd.Function is to
|
||||
essentially emit a new autograd.Function object. To do this, Dynamo
|
||||
traces fwd and bwd graph and then inserts a AutogradFunctionApply HOP in
|
||||
the graph that call the traced fwd and bwd graph in the `forward` and
|
||||
`backward` methods respectively. AOTDispatcher desugars this HOP and
|
||||
just inlines the hop fwd and bwd into the main graph during its tracing.
|
||||
|
||||
However, the traced forward and backward graphs cannot be directly
|
||||
placed in the new autograd.Function because autograd.Function has some
|
||||
requirements.
|
||||
|
||||
a) # fwd graph inputs = # bwd graph outputs
|
||||
b) # fwd graph outputs = # bwd graph inputs
|
||||
c) Since the graphs do not have ctx variable, we have to manually return
|
||||
the saved_tensors from the forward and have additional inputs in the
|
||||
backward, and wire the connections.
|
||||
|
||||
Unfortunately, reworking the initial traced fwd and bwd graphs to
|
||||
satisfy the above 3 conditions leads to a very tedious codebase.
|
||||
|
||||
Lets look at an example
|
||||
|
||||
class Foo:
|
||||
def __init__(self):
|
||||
self.a = 4
|
||||
|
||||
class MySin(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
def forward(ctx, x, foo):
|
||||
ctx.save_for_backward(x)
|
||||
return x.sin()
|
||||
return x.sin() + foo.a
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
x, = ctx.saved_tensors
|
||||
return grad * x.cos()
|
||||
|
||||
We want the resulting graphs to look like:
|
||||
def fwd(ctx, x):
|
||||
|
||||
# Note that Dynamo lifts the foo_a directly as an input.
|
||||
def fwd(ctx, x, foo_a):
|
||||
# (output, saved tensors / attrs)
|
||||
return (x.sin(), [x])
|
||||
# bwd(ctx, grad0, grad1, ..., gradn, *saved_tensors_or_attrs)
|
||||
return (x.sin() + foo_a, (x))
|
||||
|
||||
# Note that backward graph has None as the second output to match the
|
||||
# fwd requirements (even though the original backward function has just
|
||||
# output)
|
||||
def bwd(ctx, grad, x):
|
||||
return grad * x.cos()
|
||||
return grad * x.cos(), None
|
||||
|
||||
|
||||
To accomplish this, we're going to:
|
||||
1. Construct a ctx object
|
||||
2. (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph on MySin.forward (manually_set_inputs=True)
|
||||
3. (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph on MySin.backward, while manually setting
|
||||
the ctx and grad inputs.
|
||||
4. Manually rewriting the fwd graph's output to be (output, stuff_that_gets_used in bwd_graph)
|
||||
Getting from 3 to 4 is pretty elegant: stuff_that_gets_used in bwd graph is
|
||||
just the bwd_freevars returned from speculate_subgraph, assuming MySin.backward
|
||||
doesn't capture any arguments.
|
||||
All these steps work if MySin.backward doesn't capture any values. This is a
|
||||
limitation in general that we should check for.
|
||||
2. Speculate subgraph forward
|
||||
3. Speculate subgraph backward
|
||||
4. rewired_bwd_graph_inputs - Use the traced fwd graph as the anchor point, and rewire the backward graph outputs
|
||||
5. handle_saved_tensors_wiring - Hhandle the saved tensors, as mentioned in (c)
|
||||
"""
|
||||
|
||||
prev_side_effects = tx.output.side_effects.clone()
|
||||
fwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
|
||||
tx.output,
|
||||
parent=tx.output.current_tracer,
|
||||
source_target="autograd.Function",
|
||||
)
|
||||
|
||||
ctx = self.prepare_ctx_vt(tx, args, kwargs)
|
||||
|
||||
fwd_fn, fwd_out, fwd_graph, fwd_freevars, fwd_graph_output_vts = (
|
||||
self.trace_forward_graph(tx, ctx, fwd_tracer, args, kwargs)
|
||||
)
|
||||
|
||||
bwd_args, bwd_out, bwd_graph, bwd_freevars, bwd_graph_output_vts = (
|
||||
self.trace_backward_graph(tx, ctx, fwd_tracer, fwd_out, fwd_fn)
|
||||
)
|
||||
|
||||
# At this point, the fwd_out represents the output of the forward
|
||||
# method. fwd_graph represents the tensor computation, its input and
|
||||
# output do not match the original forward method. Same is true for the
|
||||
# bwd_out and bwd_graph. However, in order to create a new
|
||||
# autograd.Function to pass to the lower compiler, the fwd and bwd graph
|
||||
# must be "consistent".
|
||||
self.rewire_bwd_graph_inputs(
|
||||
fwd_out, fwd_graph, fwd_freevars, bwd_out, bwd_graph, bwd_freevars, args
|
||||
)
|
||||
|
||||
fwd_graph, bwd_graph = self.handle_saved_tensors_wiring(
|
||||
ctx,
|
||||
fwd_out,
|
||||
fwd_graph,
|
||||
fwd_freevars,
|
||||
fwd_graph_output_vts,
|
||||
bwd_out,
|
||||
bwd_graph,
|
||||
bwd_freevars,
|
||||
bwd_args,
|
||||
)
|
||||
|
||||
# If users call ctx.mark_non_differentiable, we should capture these output tensors who
|
||||
# are marked as non-differentiable and pass them to ApplyTemplate
|
||||
# at torch._functorch.autograd_function.AutogradFunctionApply for reconstruction.
|
||||
non_differentiable_idx = []
|
||||
if ctx.non_differentiable is not None:
|
||||
non_differentiable_set = set(ctx.non_differentiable)
|
||||
assert isinstance(fwd_out, variables.BaseListVariable)
|
||||
for i, x in enumerate(fwd_out.items):
|
||||
if (
|
||||
isinstance(x, variables.TensorVariable)
|
||||
and x.as_proxy() in non_differentiable_set
|
||||
):
|
||||
non_differentiable_idx.append(i)
|
||||
|
||||
# Store fwd_body
|
||||
fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
|
||||
fwd_name = tx.output.install_subgraph(
|
||||
"fwd_body",
|
||||
torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph),
|
||||
)
|
||||
|
||||
fwd_node = make_attr(tx, fwd_name)
|
||||
|
||||
# Store bwd_body
|
||||
bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
|
||||
bwd_name = tx.output.install_subgraph(
|
||||
"bwd_body",
|
||||
torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph),
|
||||
)
|
||||
|
||||
bwd_node = make_attr(tx, bwd_name)
|
||||
|
||||
p_args = (
|
||||
fwd_node,
|
||||
bwd_node,
|
||||
*list(fwd_freevars.keys()),
|
||||
)
|
||||
kwargs = {
|
||||
"non_differentiable_idx": non_differentiable_idx,
|
||||
}
|
||||
|
||||
# Store the invocation as a call
|
||||
from torch._functorch.autograd_function import autograd_function_apply
|
||||
|
||||
# We use speculate_subgraph to get the fwd graph, but it's always under no grad mode like what eager mode does.
|
||||
# The fwd outputs (tensor's example_value) need to be inferred from fake tensor prop to get the correct attributes
|
||||
# (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing.
|
||||
# Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it.
|
||||
with enable_python_dispatcher():
|
||||
with tx.output.fake_mode:
|
||||
fwd_freevars_args = [_get_fake_value(arg) for arg in fwd_freevars]
|
||||
fake_args = (
|
||||
tx.output.nn_modules[fwd_node.node.name],
|
||||
tx.output.nn_modules[bwd_node.node.name],
|
||||
*fwd_freevars_args,
|
||||
)
|
||||
example_value = autograd_function_apply(*fake_args, **kwargs)
|
||||
|
||||
flat_variable = add_call_function(
|
||||
tx, autograd_function_apply, p_args, kwargs, example_value
|
||||
)
|
||||
overwrite_tensor_vt_proxy(fwd_graph_output_vts, flat_variable)
|
||||
overwrite_tensor_vt_requires_grad(fwd_graph_output_vts, flat_variable)
|
||||
return fwd_out
|
||||
|
||||
def prepare_ctx_vt(self, tx, args, kwargs):
|
||||
from . import AutogradFunctionContextVariable
|
||||
|
||||
ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
|
||||
with discard_graph_changes(tx):
|
||||
# A little hacky, but we need a dummy ctx proxy for speculate_subgraph.
|
||||
@ -3742,37 +3881,39 @@ class AutogradFunctionApplyVariable(VariableTracker):
|
||||
)
|
||||
set_example_value(proxy.node, ctx.value)
|
||||
ctx.proxy = proxy
|
||||
return ctx
|
||||
|
||||
if isinstance(self.fwd_graph, types.FunctionType):
|
||||
fwd_fn = UserFunctionVariable(self.fwd_graph)
|
||||
fwd_args = [ctx, *args]
|
||||
elif isinstance(self.fwd_graph, types.MethodType):
|
||||
fwd_fn = UserMethodVariable(
|
||||
self.fwd_graph.__func__,
|
||||
UserDefinedClassVariable(self.fwd_graph.__class__),
|
||||
)
|
||||
fwd_args = [fwd_fn.obj, ctx, *args]
|
||||
else:
|
||||
unimplemented(
|
||||
gb_type="autograd.Function.apply: non-function or method forward",
|
||||
context=str(self.fwd_graph),
|
||||
explanation="Expected forward function to be a function or method.",
|
||||
hints=[],
|
||||
)
|
||||
def trace_forward_graph(self, tx, ctx, fwd_tracer, args, kwargs):
|
||||
from torch._functorch.autograd_function import autograd_function_trace_helper
|
||||
|
||||
fwd_fn, fwd_args = self.prepare_fn_vt(ctx, "forward", args)
|
||||
|
||||
# autograd.Function forward does a few things like running in no_grad
|
||||
# mode and also appying view_as for input tensors that are returned as
|
||||
# outputs. Therefore, we wrap the original forward in a helper that have
|
||||
# those extra bits for Dynamo to trace.
|
||||
fwd_fn = _make_inlined(tx, autograd_function_trace_helper)(fwd_fn)
|
||||
|
||||
# Speculate subgraph on the fwd
|
||||
(fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
|
||||
tx,
|
||||
fwd_fn,
|
||||
fwd_args,
|
||||
kwargs,
|
||||
"autograd.Function",
|
||||
enable_grad=False,
|
||||
set_subgraph_inputs="semi_automatic",
|
||||
restore_side_effects=False,
|
||||
tracer=fwd_tracer,
|
||||
fwd_out, fwd_graph, fwd_freevars, fwd_graph_output_vts = (
|
||||
speculate_subgraph_with_auto_output_flattening(
|
||||
tx,
|
||||
fwd_fn,
|
||||
fwd_args,
|
||||
kwargs,
|
||||
"autograd.Function",
|
||||
enable_grad=None,
|
||||
set_subgraph_inputs="automatic",
|
||||
restore_side_effects=False,
|
||||
tracer=fwd_tracer,
|
||||
)
|
||||
)
|
||||
|
||||
# For inputs that are not used but need to be captured, so that the gradient is tracked.
|
||||
for arg in args:
|
||||
if isinstance(arg, variables.TensorVariable):
|
||||
fwd_tracer.maybe_lift_tracked_freevar_to_input(arg.as_proxy())
|
||||
|
||||
if ctx in tx.output.side_effects.store_attr_mutations:
|
||||
if (
|
||||
"_materialize_non_diff_grads"
|
||||
@ -3786,39 +3927,38 @@ class AutogradFunctionApplyVariable(VariableTracker):
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
return fwd_fn, fwd_out, fwd_graph, fwd_freevars, fwd_graph_output_vts
|
||||
|
||||
def trace_backward_graph(self, tx, ctx, fwd_tracer, fwd_out, fwd_fn):
|
||||
from . import UserDefinedClassVariable, UserFunctionVariable, UserMethodVariable
|
||||
|
||||
# Note that for the forward, we do not restore side effects, because we
|
||||
# want the later tracing to see the side-effects. But for backward, we
|
||||
# are just trying to capture the graph, and therefore we must restore
|
||||
# the side effects.
|
||||
prev_side_effects = tx.output.side_effects
|
||||
|
||||
# Speculate subgraph on the backward. We make the bwd tracer a child of
|
||||
# the fwd tracer, because backward may rely on tensors/attrs created in
|
||||
# the fwd tracer.
|
||||
bwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
|
||||
tx.output,
|
||||
parent=fwd_tracer,
|
||||
source_target="autograd.Function",
|
||||
)
|
||||
|
||||
# Speculate subgraph on the backward. We make the
|
||||
# bwd tracer a child of the fwd tracer, because backward may rely on
|
||||
# tensors/attrs created in the fwd tracer.
|
||||
|
||||
if isinstance(fwd_out, variables.BaseListVariable):
|
||||
bwd_args = [ctx, *fwd_out.items]
|
||||
bwd_args = []
|
||||
if isinstance(fwd_out, variables.TensorVariable):
|
||||
bwd_args.append(fwd_out)
|
||||
else:
|
||||
bwd_args = [ctx, fwd_out]
|
||||
assert isinstance(fwd_out, variables.BaseListVariable)
|
||||
for i in fwd_out.items:
|
||||
if isinstance(i, variables.TensorVariable):
|
||||
bwd_args.append(i)
|
||||
else:
|
||||
bwd_args.append(ConstantVariable.create(None))
|
||||
|
||||
bwd_src = AttrSource(self.parent_source, member="backward")
|
||||
if isinstance(self.bwd_graph, types.FunctionType):
|
||||
bwd_fn = UserFunctionVariable(self.bwd_graph, source=bwd_src)
|
||||
elif isinstance(self.bwd_graph, types.MethodType):
|
||||
bwd_fn = UserMethodVariable(
|
||||
self.bwd_graph.__func__,
|
||||
UserDefinedClassVariable(self.bwd_graph.__class__),
|
||||
source=bwd_src,
|
||||
)
|
||||
bwd_args = [bwd_fn.obj, *bwd_args]
|
||||
else:
|
||||
unimplemented(
|
||||
gb_type="autograd.Function.apply: non-function or method backward",
|
||||
context=str(self.bwd_graph),
|
||||
explanation="Expected backward function to be a function or method.",
|
||||
hints=[],
|
||||
)
|
||||
bwd_fn, bwd_args = self.prepare_fn_vt(ctx, "backward", bwd_args)
|
||||
|
||||
def is_strict_for(v: VariableTracker):
|
||||
if isinstance(v, variables.TensorVariable):
|
||||
@ -3826,21 +3966,27 @@ class AutogradFunctionApplyVariable(VariableTracker):
|
||||
return v.proxy.tracer is not fwd_tracer
|
||||
return True
|
||||
|
||||
# automatic with new placeholder relies on the function arg names to
|
||||
# create a new proxy. Also, it will always INSERT a tensor placeholder
|
||||
# as input, even though it might not be used in the graph. This allows
|
||||
# us to make a mapping for the backward graph.
|
||||
with (
|
||||
tx.output.subtracer(fwd_fn, fwd_tracer),
|
||||
tx.strict_translation_mode(is_strict_for),
|
||||
):
|
||||
try:
|
||||
(bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph(
|
||||
tx,
|
||||
bwd_fn,
|
||||
bwd_args,
|
||||
kwargs,
|
||||
"autograd.Function",
|
||||
enable_grad=False,
|
||||
set_subgraph_inputs="manual",
|
||||
restore_side_effects=False,
|
||||
tracer=bwd_tracer,
|
||||
bwd_out, bwd_graph, bwd_freevars, bwd_graph_output_vts = (
|
||||
speculate_subgraph_with_auto_output_flattening(
|
||||
tx,
|
||||
bwd_fn,
|
||||
bwd_args,
|
||||
{},
|
||||
"autograd.Function",
|
||||
enable_grad=False,
|
||||
set_subgraph_inputs="automatic_with_new_placeholder",
|
||||
restore_side_effects=False,
|
||||
tracer=bwd_tracer,
|
||||
)
|
||||
)
|
||||
except torch._dynamo.exc.Unsupported as e:
|
||||
if isinstance(
|
||||
@ -3857,16 +4003,14 @@ class AutogradFunctionApplyVariable(VariableTracker):
|
||||
autograd_function_backward_rewritten,
|
||||
)
|
||||
|
||||
if isinstance(self.bwd_graph, types.FunctionType):
|
||||
if isinstance(self.bwd_fn, types.FunctionType):
|
||||
bwd_fn = UserFunctionVariable(
|
||||
autograd_function_backward_rewritten(self.bwd_graph)
|
||||
autograd_function_backward_rewritten(self.bwd_fn)
|
||||
)
|
||||
elif isinstance(self.bwd_graph, types.MethodType):
|
||||
elif isinstance(self.bwd_fn, types.MethodType):
|
||||
bwd_fn = UserMethodVariable(
|
||||
autograd_function_backward_rewritten(
|
||||
self.bwd_graph.__func__
|
||||
),
|
||||
UserDefinedClassVariable(self.bwd_graph.__class__),
|
||||
autograd_function_backward_rewritten(self.bwd_fn.__func__),
|
||||
UserDefinedClassVariable(self.bwd_fn.__class__),
|
||||
)
|
||||
else:
|
||||
unimplemented(
|
||||
@ -3880,190 +4024,215 @@ class AutogradFunctionApplyVariable(VariableTracker):
|
||||
"torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops",
|
||||
[],
|
||||
):
|
||||
(bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph(
|
||||
tx,
|
||||
bwd_fn,
|
||||
bwd_args,
|
||||
kwargs,
|
||||
"autograd.Function",
|
||||
enable_grad=False,
|
||||
set_subgraph_inputs="manual",
|
||||
restore_side_effects=False,
|
||||
tracer=bwd_tracer,
|
||||
bwd_out, bwd_graph, bwd_freevars, bwd_graph_output_vts = (
|
||||
speculate_subgraph_with_auto_output_flattening(
|
||||
tx,
|
||||
bwd_fn,
|
||||
bwd_args,
|
||||
{},
|
||||
"autograd.Function",
|
||||
enable_grad=False,
|
||||
set_subgraph_inputs="automatic_with_new_placeholder",
|
||||
restore_side_effects=False,
|
||||
tracer=bwd_tracer,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
tx.output.side_effects = prev_side_effects
|
||||
return bwd_args, bwd_out, bwd_graph, bwd_freevars, bwd_graph_output_vts
|
||||
|
||||
# TODO: assert that bwd_graph didn't capture values that were
|
||||
# not created inside fwd_graph.
|
||||
def rewire_bwd_graph_inputs(
|
||||
self,
|
||||
fwd_out,
|
||||
fwd_graph,
|
||||
fwd_freevars,
|
||||
bwd_out,
|
||||
bwd_graph,
|
||||
bwd_freevars,
|
||||
orig_fwd_args,
|
||||
):
|
||||
# Ensure fwd-input and bwd-output consistency - autograd.Function
|
||||
# requires that the inputs of the forward line up correctly with the
|
||||
# outputs of the backward. This is the responsibily of the user. Now,
|
||||
# when Dynamo is creating a new autograd.Function, it is now Dynamo
|
||||
# responsibility to do this lineup. To do this, we use the original user
|
||||
# point as the anchor point to provide this mapping.
|
||||
|
||||
# TODO(oulgen): Ideally, we would not do a linear search for output
|
||||
# node but as things currently are there could be nodes after the
|
||||
# output node
|
||||
# This is bug prone as if there's code after the output node, then
|
||||
# graph.output will append the output at the very end
|
||||
# This might be a behavior difference
|
||||
# Some more description to understand the following codebase
|
||||
#
|
||||
# fwd_freevars/bwd_freevars: A map from outer graph proxy to inner graph
|
||||
# placeholder proxy. The keys are ALWAYS outer graph proxy, these could
|
||||
# be inputs in the main graph or also intermediates in the main graph
|
||||
# that are passed as inputs to the subgraph.
|
||||
#
|
||||
# orig_fwd_args - Variable trackers for the forward graph inputs. Since
|
||||
# these are inputs, the tensor variables here are all OUTER graph proxies.
|
||||
|
||||
# If users call ctx.mark_non_differentiable, we should capture these output tensors who
|
||||
# are marked as non-differentiable and pass them to ApplyTemplate
|
||||
# at torch._functorch.autograd_function.AutogradFunctionApply for reconstruction.
|
||||
non_differentiable_idx = []
|
||||
if ctx.non_differentiable is not None:
|
||||
non_differentiable_set = set(ctx.non_differentiable)
|
||||
assert isinstance(fwd_out, variables.BaseListVariable)
|
||||
for i, x in enumerate(fwd_out.items):
|
||||
if (
|
||||
isinstance(x, variables.TensorVariable)
|
||||
and x.as_proxy() in non_differentiable_set
|
||||
):
|
||||
non_differentiable_idx.append(i)
|
||||
# bwd_outs - Variable trackers for the backward output. Since these are
|
||||
# output, the variable trackers here point to INNER graph proxies. The
|
||||
# special case is when an input is passed directly to the output of the
|
||||
# backward graph, in which case, the variable tracker can still point to
|
||||
# the outer graph proxy.
|
||||
|
||||
# Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd)
|
||||
for node in fwd_graph.find_nodes(op="output"):
|
||||
fwd_graph.erase_node(node)
|
||||
break
|
||||
# To make the fwd-inputs and bwd-outputs consistent, we just rewire the
|
||||
# backward graph outputs to match with the forward graph inputs. To do
|
||||
# this, we first rely on orig_fwd_args and bwd_outs to make a mapping of
|
||||
# outer_proxy to inner graph proxy. And walk through the fwd_graph
|
||||
# inputs and this map to find the bwd outputs.
|
||||
|
||||
# Because we lift the bwd_freevars as inputs of the bwd_graph,
|
||||
# we have to manually add the bwd_freevars as output of fwd_graph.
|
||||
# However, the bwd_freevars got from speculate_subgraph use the Proxies in the bwd_graph,
|
||||
# we need to convert them to Proxies in the fwd_graph and then generate new fwd_graph output.
|
||||
fwd_proxy_of_bwd_freevars = []
|
||||
for k in bwd_freevars:
|
||||
if k in fwd_freevars:
|
||||
fwd_proxy_of_bwd_freevars.append(fwd_freevars[k])
|
||||
else:
|
||||
fwd_proxy_of_bwd_freevars.append(k)
|
||||
def get_bwd_node(vt):
|
||||
# Backward tensor vt here can be - (1) an intermediate, or (2) input
|
||||
# to the backward graph. If it is an input to the backward graph, we have to lookup bwd_freevars to get the inner proxy.
|
||||
return bwd_freevars.get(vt.proxy, vt.proxy).node
|
||||
|
||||
def unwrap_proxy(x):
|
||||
if isinstance(x, torch.fx.Proxy):
|
||||
return x.node
|
||||
else:
|
||||
assert variables.ConstantVariable.is_literal(x), (
|
||||
f"Only constant is allowed. Got {x}"
|
||||
)
|
||||
return x
|
||||
# Find the mapping between orig_fwd_args and bwd_out
|
||||
outer_fwd_proxy_to_bwd_node = {}
|
||||
if isinstance(bwd_out, variables.BaseListVariable):
|
||||
bwd_outs = bwd_out.items
|
||||
for idx, fwd_arg in enumerate(orig_fwd_args):
|
||||
# We care about tensor args. For non-tensor args, the bwd output returns None.
|
||||
if isinstance(fwd_arg, variables.TensorVariable):
|
||||
bwd_out_at_idx = bwd_outs[idx]
|
||||
if isinstance(bwd_out_at_idx, variables.TensorVariable):
|
||||
outer_fwd_proxy_to_bwd_node[fwd_arg.proxy] = get_bwd_node(
|
||||
bwd_out_at_idx
|
||||
)
|
||||
else:
|
||||
# backward can return None at the output
|
||||
assert (
|
||||
isinstance(bwd_out_at_idx, variables.ConstantVariable)
|
||||
and bwd_out_at_idx.value is None
|
||||
)
|
||||
outer_fwd_proxy_to_bwd_node[fwd_arg.proxy] = None
|
||||
|
||||
new_fwd_graph_outputs = (fwd_out.as_proxy(), fwd_proxy_of_bwd_freevars)
|
||||
new_fwd_graph_outputs = pytree.tree_map(unwrap_proxy, new_fwd_graph_outputs)
|
||||
fwd_graph.output(new_fwd_graph_outputs)
|
||||
fwd_graph.lint()
|
||||
elif isinstance(bwd_out, variables.TensorVariable):
|
||||
outer_fwd_proxy_to_bwd_node[orig_fwd_args[0].proxy] = get_bwd_node(bwd_out)
|
||||
|
||||
# Store fwd_body
|
||||
fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
|
||||
fwd_name = tx.output.install_subgraph(
|
||||
"fwd_body",
|
||||
torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph),
|
||||
)
|
||||
# Ideally, we should have walked through the fwd placeholders. But we
|
||||
# can instead walk through the fwd_freevars, which is a insertion sorted
|
||||
# dictionary and therefore represents the outer_proxies for the
|
||||
# placeholder in the same order as that as placeholders.
|
||||
rewired_bwd_outputs = [
|
||||
outer_fwd_proxy_to_bwd_node.get(fwd_proxy) for fwd_proxy in fwd_freevars
|
||||
]
|
||||
|
||||
fwd_node = make_attr(tx, fwd_name)
|
||||
|
||||
# The type of original args can be arbitrary, but we only support basic type in FX graph.
|
||||
# So the speculated subgraph input includes original tensor args and the lifted freevars.
|
||||
# We need to filter out the original tensor args and concat them with the lifted freevars
|
||||
# to generate the proxy args for the FX call_function node.
|
||||
filtered_args = []
|
||||
# A boolean list to mark if the type of corresponding argument is tensor.
|
||||
# This is used to determine if a FX node's argument should be an argument of
|
||||
# ApplyTemplate.forward and if we should skip the output from ApplyTemplate.backward
|
||||
# at torch._functorch.autograd_function.AutogradFunctionApply.
|
||||
args_tensor_mask = [False] * len(args)
|
||||
for i, arg in enumerate(args):
|
||||
if isinstance(arg, (variables.TensorVariable, variables.SymNodeVariable)):
|
||||
filtered_args.append(arg)
|
||||
args_tensor_mask[i] = True
|
||||
|
||||
# Rewrite the output of bwd_graph to remove the grad output for the non-Tensor args.
|
||||
new_bwd_graph_outputs = None
|
||||
for node in bwd_graph.find_nodes(op="output"):
|
||||
bwd_graph.erase_node(node)
|
||||
break
|
||||
|
||||
# The same as the above fwd proxies, we need to use the bwd proxies in the bwd_graph
|
||||
# if some of the output is from fwd_freevars.
|
||||
bwd_out_proxy = bwd_out.as_proxy()
|
||||
bwd_proxy_of_fwd_freevars = []
|
||||
if isinstance(bwd_out_proxy, (tuple, list)):
|
||||
for k in bwd_out_proxy:
|
||||
if k in bwd_freevars:
|
||||
bwd_proxy_of_fwd_freevars.append(bwd_freevars[k])
|
||||
else:
|
||||
bwd_proxy_of_fwd_freevars.append(k)
|
||||
else:
|
||||
if bwd_out_proxy in bwd_freevars:
|
||||
bwd_proxy_of_fwd_freevars = bwd_freevars[bwd_out_proxy]
|
||||
else:
|
||||
bwd_proxy_of_fwd_freevars = bwd_out_proxy
|
||||
|
||||
# Remove bwd output for non-Tensor args.
|
||||
output_proxy = bwd_proxy_of_fwd_freevars
|
||||
if isinstance(output_proxy, (tuple, list)):
|
||||
new_bwd_graph_outputs = ()
|
||||
for x, mask in zip(output_proxy, args_tensor_mask):
|
||||
if mask:
|
||||
new_bwd_graph_outputs = new_bwd_graph_outputs + (x,)
|
||||
else:
|
||||
assert x is None, f"Grad of non-Tensor arg {x} is not None."
|
||||
else:
|
||||
new_bwd_graph_outputs = output_proxy
|
||||
|
||||
# Update the bwd graph output.
|
||||
new_bwd_graph_outputs = pytree.tree_map(
|
||||
lambda x: None if x is None else x.node, new_bwd_graph_outputs
|
||||
)
|
||||
bwd_graph.output(new_bwd_graph_outputs)
|
||||
bwd_graph.output(tuple(rewired_bwd_outputs))
|
||||
bwd_graph.lint()
|
||||
|
||||
# Store bwd_body
|
||||
bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
|
||||
bwd_name = tx.output.install_subgraph(
|
||||
"bwd_body",
|
||||
torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph),
|
||||
)
|
||||
def handle_saved_tensors_wiring(
|
||||
self,
|
||||
ctx,
|
||||
fwd_out,
|
||||
fwd_graph,
|
||||
fwd_freevars,
|
||||
fwd_graph_body_outputs,
|
||||
bwd_out,
|
||||
bwd_graph,
|
||||
bwd_freevars,
|
||||
bwd_args,
|
||||
):
|
||||
# First we need to map the existing forward graph outputs to bwd graph inputs.
|
||||
bwd_input_nodes = list(bwd_graph.find_nodes(op="placeholder"))
|
||||
|
||||
bwd_node = make_attr(tx, bwd_name)
|
||||
fwd_vt_to_bwd_node = {}
|
||||
bwd_idx = 0
|
||||
if isinstance(fwd_out, variables.BaseListVariable):
|
||||
for fwd_vt in fwd_out.items:
|
||||
if isinstance(fwd_vt, variables.TensorVariable):
|
||||
fwd_vt_to_bwd_node[fwd_vt] = bwd_input_nodes[bwd_idx]
|
||||
bwd_idx += 1
|
||||
else:
|
||||
if isinstance(fwd_out, variables.TensorVariable):
|
||||
fwd_vt_to_bwd_node[fwd_out] = bwd_input_nodes[bwd_idx]
|
||||
bwd_idx += 1
|
||||
|
||||
tx.output.side_effects = prev_side_effects
|
||||
rewired_bwd_graph_inputs = []
|
||||
for fwd_graph_vt in fwd_graph_body_outputs:
|
||||
rewired_bwd_graph_inputs.append(fwd_vt_to_bwd_node.get(fwd_graph_vt))
|
||||
|
||||
p_args = (
|
||||
fwd_node,
|
||||
bwd_node,
|
||||
*([arg.as_proxy() for arg in filtered_args] + list(fwd_freevars.keys())),
|
||||
)
|
||||
kwargs = {
|
||||
"args_tensor_mask": args_tensor_mask,
|
||||
"non_differentiable_idx": non_differentiable_idx,
|
||||
}
|
||||
# bwd_freevars also contains the symint passed from forward to backward
|
||||
extra_fwd_output_nodes = []
|
||||
for fwd_proxy, bwd_inner_proxy in bwd_freevars.items():
|
||||
# For backward, its easy, just get the node from bwd_inner_proxy
|
||||
rewired_bwd_graph_inputs.append(bwd_inner_proxy.node)
|
||||
|
||||
# Store the invocation as a call
|
||||
from torch._functorch.autograd_function import autograd_function_apply
|
||||
# For the fwd_proxy, it could be a proxy from the outer graph, or it
|
||||
# could be an intermediate.
|
||||
# First ensure that's its inner fwd proxy
|
||||
inner_fwd_proxy = fwd_freevars.get(fwd_proxy, fwd_proxy)
|
||||
|
||||
# We use speculate_subgraph to get the fwd graph, but it's always under no grad mode like what eager mode does.
|
||||
# The fwd outputs (tensor's example_value) need to be inferred from fake tensor prop to get the correct attributes
|
||||
# (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing.
|
||||
# Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it.
|
||||
with enable_python_dispatcher(), tx.output.fake_mode:
|
||||
fake_args = (
|
||||
tx.output.nn_modules[fwd_node.node.name],
|
||||
tx.output.nn_modules[bwd_node.node.name],
|
||||
*(
|
||||
[
|
||||
_get_fake_value(arg)
|
||||
for arg in filtered_args + list(fwd_freevars.keys())
|
||||
]
|
||||
),
|
||||
extra_fwd_output_nodes.append(inner_fwd_proxy.node)
|
||||
|
||||
# We have all the info, lets change the fwd graph
|
||||
fwd_output_nodes = []
|
||||
for node in fwd_graph.find_nodes(op="output"):
|
||||
fwd_output_nodes = node.args[0]
|
||||
fwd_graph.erase_node(node)
|
||||
break
|
||||
|
||||
new_fwd_graph_outputs = (fwd_output_nodes, tuple(extra_fwd_output_nodes))
|
||||
fwd_graph.output(new_fwd_graph_outputs)
|
||||
fwd_graph.lint()
|
||||
|
||||
# Now lets change the bwd graph.
|
||||
new_graph = torch.fx.Graph()
|
||||
env = {}
|
||||
|
||||
count = itertools.count()
|
||||
|
||||
for node in rewired_bwd_graph_inputs:
|
||||
if node is None:
|
||||
new_node = new_graph.placeholder(f"unused_{next(count)}")
|
||||
else:
|
||||
new_node = new_graph.placeholder(node.name)
|
||||
new_node.meta = copy.copy(node.meta)
|
||||
env[node] = new_node
|
||||
|
||||
for node in bwd_graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
assert node in env
|
||||
else:
|
||||
env[node] = new_graph.node_copy(node, lambda x: env[x])
|
||||
env[node].meta = copy.copy(node.meta)
|
||||
|
||||
new_graph.lint()
|
||||
return fwd_graph, new_graph
|
||||
|
||||
def prepare_fn_vt(self, ctx, method_name, args):
|
||||
from . import UserDefinedClassVariable, UserFunctionVariable, UserMethodVariable
|
||||
|
||||
source = None
|
||||
if self.parent_source:
|
||||
source = AttrSource(self.parent_source, member=method_name)
|
||||
|
||||
if method_name == "forward":
|
||||
fn = self.fwd_fn
|
||||
else:
|
||||
fn = self.bwd_fn
|
||||
|
||||
if isinstance(fn, types.FunctionType):
|
||||
fn_vt = UserFunctionVariable(fn, source=source)
|
||||
fn_args = [ctx, *args]
|
||||
elif isinstance(fn, types.MethodType):
|
||||
cls_vt = UserDefinedClassVariable(fn.__class__)
|
||||
fn_vt = UserMethodVariable(
|
||||
fn.__func__,
|
||||
cls_vt,
|
||||
source=source,
|
||||
)
|
||||
example_value = autograd_function_apply(*fake_args, **kwargs)
|
||||
|
||||
return wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
autograd_function_apply,
|
||||
args=p_args,
|
||||
kwargs=kwargs,
|
||||
),
|
||||
example_value=example_value,
|
||||
)
|
||||
fn_args = [cls_vt, ctx, *args]
|
||||
else:
|
||||
unimplemented(
|
||||
gb_type="autograd.Function.apply: non-function or method forward",
|
||||
context=str(fn),
|
||||
explanation=f"Expected {method_name} to be a function or method.",
|
||||
hints=[],
|
||||
)
|
||||
return fn_vt, fn_args
|
||||
|
||||
|
||||
def _get_fake_value(x):
|
||||
|
||||
@ -701,6 +701,7 @@ class AutogradFunctionVariable(VariableTracker):
|
||||
VariableTracker.visit(visit, (args, kwargs))
|
||||
|
||||
if requires_grad and torch.is_grad_enabled():
|
||||
source = self.source
|
||||
if config.capture_autograd_function is False:
|
||||
warnings.warn(
|
||||
"The config.capture_autograd_function flag is deprecated, it's now always true."
|
||||
@ -720,6 +721,10 @@ class AutogradFunctionVariable(VariableTracker):
|
||||
forward_fn = autograd_function_forward_rewritten(
|
||||
self.fn_cls.forward, self.fn_cls.setup_context
|
||||
)
|
||||
# The forward points to a new function now, so we can't use the
|
||||
# old source. Later on, we guard specifically on
|
||||
# is_setup_ctx_defined
|
||||
source = None
|
||||
|
||||
vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined]
|
||||
if vjp_fn is not torch.autograd.Function.vjp:
|
||||
@ -752,29 +757,23 @@ class AutogradFunctionVariable(VariableTracker):
|
||||
|
||||
from .higher_order_ops import AutogradFunctionApplyVariable
|
||||
|
||||
source = self.source
|
||||
if source is None:
|
||||
if source is None and not is_setup_ctx_defined:
|
||||
source = AttrSource(
|
||||
tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
|
||||
)
|
||||
apply_source = source and AttrSource(source, member="apply")
|
||||
|
||||
val = AutogradFunctionApplyVariable(
|
||||
forward_fn,
|
||||
self.fn_cls.backward,
|
||||
source,
|
||||
source=AttrSource(source, member="apply"),
|
||||
source=apply_source,
|
||||
).call_function(tx, args, kwargs)
|
||||
# Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
|
||||
# the forward function, as we don't want to generate guards for new_forward.__closure__
|
||||
# if forward is rewritten by autograd_function_forward_rewritten.
|
||||
# But we still need to generate correct guards for the original forward and setup_context
|
||||
# functions, so we have to add guards manually.
|
||||
if self.source:
|
||||
if self.source and is_setup_ctx_defined:
|
||||
fwd_src = AttrSource(self.source, "forward")
|
||||
install_guard(fwd_src.make_guard(GuardBuilder.CLOSURE_MATCH))
|
||||
if is_setup_ctx_defined:
|
||||
setup_ctx_src = AttrSource(self.source, "setup_context")
|
||||
install_guard(setup_ctx_src.make_guard(GuardBuilder.CLOSURE_MATCH))
|
||||
setup_ctx_src = AttrSource(self.source, "setup_context")
|
||||
install_guard(setup_ctx_src.make_guard(GuardBuilder.CLOSURE_MATCH))
|
||||
|
||||
return val
|
||||
|
||||
|
||||
@ -743,20 +743,14 @@ class AutogradFunctionApply(HigherOrderOperator):
|
||||
|
||||
def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
|
||||
saved_values = None
|
||||
args_tensor_mask = fwd_kwargs["args_tensor_mask"]
|
||||
non_differentiable_idx = fwd_kwargs["non_differentiable_idx"]
|
||||
length_of_tensor_args = sum(args_tensor_mask)
|
||||
# Filter out the original tensor args from fwd_args,
|
||||
# lifted freevars should not be args of ApplyTemplate.apply
|
||||
# since we don't need to calculate the gradients of them.
|
||||
new_fwd_args = fwd_args[:length_of_tensor_args]
|
||||
|
||||
class ApplyTemplate(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx, *args):
|
||||
nonlocal saved_values
|
||||
output, saved_values = fwd(None, *fwd_args)
|
||||
output, saved_values = fwd(*args)
|
||||
|
||||
# If users call ctx.mark_non_differentiable() in the original fwd function.
|
||||
if len(non_differentiable_idx) > 0:
|
||||
@ -770,9 +764,45 @@ class AutogradFunctionApply(HigherOrderOperator):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad):
|
||||
return bwd(None, *grad, *saved_values)
|
||||
return bwd(*grad, *saved_values)
|
||||
|
||||
return ApplyTemplate.apply(*new_fwd_args)
|
||||
return ApplyTemplate.apply(*fwd_args)
|
||||
|
||||
|
||||
autograd_function_apply = AutogradFunctionApply()
|
||||
|
||||
|
||||
# autograd.Function forward does more than just running the forward method. Most
|
||||
# of this logic is in C++. Here, we rewrite that functionality in python and let
|
||||
# Dynamo trace it. This is most probably incomplete.
|
||||
def autograd_function_trace_helper(orig_fwd):
|
||||
def inner(*args, **kwargs):
|
||||
with torch.no_grad():
|
||||
outs = orig_fwd(*args, **kwargs)
|
||||
|
||||
# Handle the case where if the input is passed on directly to the output, we call view_as
|
||||
# Refer to https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/custom_function.cpp#L254
|
||||
tensor_args = {arg for arg in args if isinstance(arg, torch.Tensor)}
|
||||
if isinstance(outs, torch.Tensor):
|
||||
if outs in tensor_args:
|
||||
return outs.view_as(outs)
|
||||
else:
|
||||
return outs
|
||||
|
||||
new_outs = []
|
||||
for out in outs:
|
||||
if isinstance(out, torch.Tensor):
|
||||
if out in tensor_args:
|
||||
new_outs.append(out.view_as(out))
|
||||
else:
|
||||
new_outs.append(out)
|
||||
else:
|
||||
new_outs.append(out)
|
||||
return tuple(new_outs)
|
||||
|
||||
# TODO - there is missing functionality here, where autograd.Function
|
||||
# overwrites the requires_grad_ of the output tensors depending on the
|
||||
# `mark_non_differentiable`. Currently, this is handled hackily in
|
||||
# Dynamo, where we just overwrite the variable trackers requires_grad.
|
||||
|
||||
return inner
|
||||
|
||||
Reference in New Issue
Block a user