Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Add pure view support in autograd Function (#164736)
This is the same as https://github.com/pytorch/pytorch/pull/164467, but it needs to be co-deved due to internal insanity.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164736
Approved by: https://github.com/soulitzer
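As a rough sketch of the pattern the updated tests exercise (the class name NarrowView and the gradient wiring below are illustrative, not part of this PR): a custom autograd Function opts into the new behavior by setting ctx._is_pure_view in forward(), and, per the error message added in the tests, a pure-view Function is limited to one input Tensor and one output Tensor.

import torch
from torch.autograd import Function

class NarrowView(Function):
    @staticmethod
    def forward(ctx, a, pure_view):
        # New in this PR, as exercised by the tests below: mark the Function
        # as a pure view of its single Tensor input.
        ctx._is_pure_view = pure_view
        ctx.save_for_backward(a)
        return a.narrow(0, 0, 2)

    @staticmethod
    def backward(ctx, grad):
        (a,) = ctx.saved_tensors
        res = torch.zeros_like(a)
        res.narrow(0, 0, 2).copy_(grad)
        # One gradient per forward() input; the bool flag gets None.
        return res, None

a = torch.ones(4, requires_grad=True)
out = NarrowView.apply(a, True)   # out is a view of a
out.sum().backward()              # custom backward runs; a.grad == [1., 1., 0., 0.]

On builds without this PR, the flag should simply be an unused attribute on ctx, so a snippet like this still runs but behaves like a regular custom Function.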
@@ -8223,7 +8223,8 @@ for shape in [(1,), ()]:

         class IdOneOutput(Function):
             @staticmethod
-            def forward(ctx, a, b, make_view):
+            def forward(ctx, a, make_view, pure_view):
+                ctx._is_pure_view = pure_view
                 if make_view:
                     a = a.narrow(0, 0, 2)
                 else:
@@ -8237,7 +8238,8 @@ for shape in [(1,), ()]:

         class IdTwoOutput(Function):
             @staticmethod
-            def forward(ctx, a, b, make_view):
+            def forward(ctx, a, b, make_view, pure_view):
+                ctx._is_pure_view = pure_view
                 if make_view:
                     a = a.narrow(0, 0, 2)
                 else:
@@ -8251,11 +8253,12 @@ for shape in [(1,), ()]:
                     ga_nz[0] = False
                 else:
                     ga_nz[0] = True
-                return ga + gab, gab, None
+                return ga + gab, gab, None, None

         class ViewOfTemp(Function):
             @staticmethod
-            def forward(ctx, a, make_view):
+            def forward(ctx, a, make_view, pure_view):
+                ctx._is_pure_view = pure_view
                 ctx.save_for_backward(a)
                 if make_view:
                     a = a.narrow(0, 0, 2)
@@ -8270,7 +8273,7 @@ for shape in [(1,), ()]:
                 (a,) = ctx.saved_tensors
                 res = torch.zeros_like(a)
                 res.select(0, 0).copy_(grad)
-                return res, None
+                return res, None, None

         fn_id_to_inplace_on_view_err_msg = {
             "one_output": (
@@ -8279,71 +8282,96 @@ for shape in [(1,), ()]:
             ),
             "two_output": (
                 "Output 0 of IdTwoOutputBackward is a view and is being modified inplace."
-                " This view is the output of a function that returns multiple views."
+                " This view is the output of a function that returns multiple views.",
+                "Pure view custom Function can only have one input Tensor and one output Tensor."
+                " Open an issue if you need to support more.",
             ),
             "view_of_temp": (
                 "Output 0 of ViewOfTempBackward is a view and is being "
-                "modified inplace. This view was created inside a custom Function"
+                "modified inplace. This view was created inside a custom Function",
+                "a view of a leaf Variable that requires grad is being used in an in-place operation",
             ),
         }

         for fn_id in ["one_output", "two_output", "view_of_temp"]:
             for inplace in [True, False]:
                 for make_view in [True, False]:
-                    # Used for special casing the tests below
-                    output_is_a_view = make_view or fn_id == "view_of_temp"
-
-                    def fn(a, b):
-                        # never modify a, b inplace for gracheck
-                        a = a.clone()
-                        b = b.clone()
-                        if fn_id == "two_output":
-                            tmp1, tmp2 = IdTwoOutput.apply(a, b, make_view)
-                            if inplace:
-                                tmp1 += 3
-                                tmp2 += 3
-                            else:
-                                tmp1 = tmp1 + 3
-                                tmp2 = tmp2 + 3
-                            tmp = tmp1 * tmp2
-                        else:
-                            if fn_id == "one_output":
-                                tmp = IdOneOutput.apply(a, b, make_view)
-                            else:
-                                tmp = ViewOfTemp.apply(a + b, make_view)
-                            if inplace:
-                                tmp += 3
-                            else:
-                                tmp = tmp + 3
-
-                        return tmp.sum()
-
-                    a = torch.ones(2, dtype=dtype, requires_grad=True)
-                    b = torch.ones(2, dtype=dtype, requires_grad=True)
-
-                    err_msg = fn_id_to_inplace_on_view_err_msg[fn_id]
-
-                    if not inplace or not output_is_a_view:
-                        gradcheck(fn, (a, b), check_batched_grad=False)
-
-                    # Was the custom backward called properly
-                    bw_called[0] = 0
-                    ga_nz[0] = True  # For the case where the backward is called
-
-                    if inplace and output_is_a_view:
-                        with self.assertRaisesRegex(RuntimeError, err_msg):
-                            fn(a, b)
-                    else:
-                        fn(a, b).abs().backward()
-
-                    expected_called = 1
-                    expected_ga_nz = True
-
-                    if output_is_a_view and inplace:
-                        expected_called = 0
-
-                    self.assertTrue(bw_called[0] == expected_called)
-                    self.assertTrue(ga_nz[0] == expected_ga_nz)
+                    for pure_view in [True, False]:
+                        # Used for special casing the tests below
+                        output_is_a_view = make_view or fn_id == "view_of_temp"
+
+                        def fn(a, b):
+                            # never modify a, b inplace for gracheck
+                            a = a.clone()
+                            b = b.clone()
+                            if fn_id == "two_output":
+                                tmp1, tmp2 = IdTwoOutput.apply(
+                                    a, b, make_view, pure_view
+                                )
+                                if inplace:
+                                    tmp1 += 3
+                                    tmp2 += 3
+                                else:
+                                    tmp1 = tmp1 + 3
+                                    tmp2 = tmp2 + 3
+                                tmp = tmp1 * tmp2
+                            else:
+                                if fn_id == "one_output":
+                                    tmp = IdOneOutput.apply(a, make_view, pure_view)
+                                else:
+                                    tmp = ViewOfTemp.apply(a + b, make_view, pure_view)
+                                if inplace:
+                                    tmp += 3
+                                else:
+                                    tmp = tmp + 3
+
+                            return tmp.sum()
+
+                        a = torch.ones(2, dtype=dtype, requires_grad=True)
+                        b = torch.ones(2, dtype=dtype, requires_grad=True)
+
+                        err_msg = fn_id_to_inplace_on_view_err_msg[fn_id][
+                            int(pure_view)
+                        ]
+
+                        will_raise_error = (
+                            (pure_view and fn_id == "two_output")
+                            or (pure_view and fn_id == "view_of_temp" and inplace)
+                            or (not pure_view and inplace and output_is_a_view)
+                        )
+
+                        if will_raise_error:
+                            with self.assertRaisesRegex(RuntimeError, err_msg):
+                                gradcheck(fn, (a, b), check_batched_grad=False)
+                        else:
+                            gradcheck(fn, (a, b), check_batched_grad=False)
+
+                        # Was the custom backward called properly
+                        bw_called[0] = 0
+                        ga_nz[0] = True  # For the case where the backward is called
+
+                        expected_called = 1
+                        expected_ga_nz = True
+
+                        if will_raise_error:
+                            expected_called = 0
+                            with self.assertRaisesRegex(RuntimeError, err_msg):
+                                fn(a, b)
+                        else:
+                            fn(a, b).abs().backward()
+
+                        if (
+                            fn_id == "one_output"
+                            and inplace
+                            and output_is_a_view
+                            and pure_view
+                        ):
+                            # We expect the op to have been replayed and we leveraged the pure view
+                            # to re-create the graph, so the original backward was not called
+                            expected_called = 0
+
+                        self.assertTrue(bw_called[0] == expected_called)
+                        self.assertTrue(ga_nz[0] == expected_ga_nz)

     def test_autograd_simple_views_python(self):
         self._do_test_autograd_simple_views_python(torch.double)
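The interesting new case the test special-cases is one_output with make_view, inplace, and pure_view=True: with this PR, modifying the Function's view output in place is expected to succeed, and the custom backward is expected not to run because autograd replays the view to re-create the graph (see the "We expect the op to have been replayed" comment above). A hedged sketch of that scenario, with illustrative names and assuming this PR's behavior:

import torch
from torch.autograd import Function

bw_called = [0]

class PureNarrow(Function):
    @staticmethod
    def forward(ctx, a):
        ctx._is_pure_view = True  # the flag toggled by the updated tests
        return a.narrow(0, 0, 2)

    @staticmethod
    def backward(ctx, grad):
        bw_called[0] += 1
        # Pad the gradient of the narrowed view back to the input's shape.
        res = torch.zeros(4, dtype=grad.dtype)
        res.narrow(0, 0, 2).copy_(grad)
        return res

leaf = torch.ones(4, requires_grad=True)
base = leaf.clone()            # non-leaf, so in-place on a view of it is permitted
view = PureNarrow.apply(base)  # view into base
view += 3                      # in-place on the Function's view output
view.sum().backward()
# Per the new test expectations, bw_called[0] stays 0: the original backward is
# not called because the pure view lets autograd rebuild the graph by replay.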