Add pure view support in autograd Function (#164736)
This is the same as https://github.com/pytorch/pytorch/pull/164467 but it needs to be co-deved due to internal insanity.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164736
Approved by: https://github.com/soulitzer
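For context, the feature this PR adds is a private `_is_pure_view` flag that a custom `torch.autograd.Function` can set on `ctx` during `forward` when its single output is a genuine view of its single input. The sketch below is hypothetical (the `FirstTwo` class and its narrow-based backward are not part of the PR); it only illustrates, under that assumption, how the flag is meant to be used:

import torch
from torch.autograd import Function

class FirstTwo(Function):
    @staticmethod
    def forward(ctx, x):
        # Private flag added by this PR: the output below is a true view of
        # the single input, so replaying the view is enough for autograd to
        # rebuild the graph after an in-place op.
        ctx._is_pure_view = True
        ctx.input_numel = x.numel()  # plain attribute kept for backward
        return x.narrow(0, 0, 2)

    @staticmethod
    def backward(ctx, grad):
        # Gradient of narrow(0, 0, 2): scatter grad into a zero tensor of
        # the (1-D) input's shape.
        res = grad.new_zeros(ctx.input_numel)
        res[:2] = grad
        return res

a = torch.ones(4, requires_grad=True)
FirstTwo.apply(a).mul(2).sum().backward()
print(a.grad)  # tensor([2., 2., 0., 0.])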
test/test_autograd.py

@@ -8223,7 +8223,8 @@ for shape in [(1,), ()]:
         class IdOneOutput(Function):
             @staticmethod
-            def forward(ctx, a, b, make_view):
+            def forward(ctx, a, make_view, pure_view):
+                ctx._is_pure_view = pure_view
                 if make_view:
                     a = a.narrow(0, 0, 2)
                 else:
@@ -8237,7 +8238,8 @@ for shape in [(1,), ()]:
         class IdTwoOutput(Function):
             @staticmethod
-            def forward(ctx, a, b, make_view):
+            def forward(ctx, a, b, make_view, pure_view):
+                ctx._is_pure_view = pure_view
                 if make_view:
                     a = a.narrow(0, 0, 2)
                 else:
@@ -8251,11 +8253,12 @@ for shape in [(1,), ()]:
                     ga_nz[0] = False
                 else:
                     ga_nz[0] = True
-                return ga + gab, gab, None
+                return ga + gab, gab, None, None

         class ViewOfTemp(Function):
             @staticmethod
-            def forward(ctx, a, make_view):
+            def forward(ctx, a, make_view, pure_view):
+                ctx._is_pure_view = pure_view
                 ctx.save_for_backward(a)
                 if make_view:
                     a = a.narrow(0, 0, 2)
@@ -8270,7 +8273,7 @@ for shape in [(1,), ()]:
                 (a,) = ctx.saved_tensors
                 res = torch.zeros_like(a)
                 res.select(0, 0).copy_(grad)
-                return res, None
+                return res, None, None

         fn_id_to_inplace_on_view_err_msg = {
             "one_output": (
@@ -8279,71 +8282,96 @@ for shape in [(1,), ()]:
             ),
             "two_output": (
                 "Output 0 of IdTwoOutputBackward is a view and is being modified inplace."
-                " This view is the output of a function that returns multiple views."
+                " This view is the output of a function that returns multiple views.",
+                "Pure view custom Function can only have one input Tensor and one output Tensor."
+                " Open an issue if you need to support more.",
             ),
             "view_of_temp": (
                 "Output 0 of ViewOfTempBackward is a view and is being "
-                "modified inplace. This view was created inside a custom Function"
+                "modified inplace. This view was created inside a custom Function",
+                "a view of a leaf Variable that requires grad is being used in an in-place operation",
             ),
         }

         for fn_id in ["one_output", "two_output", "view_of_temp"]:
             for inplace in [True, False]:
                 for make_view in [True, False]:
-                    # Used for special casing the tests below
-                    output_is_a_view = make_view or fn_id == "view_of_temp"
+                    for pure_view in [True, False]:
+                        # Used for special casing the tests below
+                        output_is_a_view = make_view or fn_id == "view_of_temp"

-                    def fn(a, b):
-                        # never modify a, b inplace for gracheck
-                        a = a.clone()
-                        b = b.clone()
-                        if fn_id == "two_output":
-                            tmp1, tmp2 = IdTwoOutput.apply(a, b, make_view)
-                            if inplace:
-                                tmp1 += 3
-                                tmp2 += 3
-                            else:
-                                tmp1 = tmp1 + 3
-                                tmp2 = tmp2 + 3
-                            tmp = tmp1 * tmp2
-                        else:
-                            if fn_id == "one_output":
-                                tmp = IdOneOutput.apply(a, b, make_view)
-                            else:
-                                tmp = ViewOfTemp.apply(a + b, make_view)
-                            if inplace:
-                                tmp += 3
-                            else:
-                                tmp = tmp + 3
-
-                        return tmp.sum()
-
-                    a = torch.ones(2, dtype=dtype, requires_grad=True)
-                    b = torch.ones(2, dtype=dtype, requires_grad=True)
-
-                    err_msg = fn_id_to_inplace_on_view_err_msg[fn_id]
-
-                    if not inplace or not output_is_a_view:
-                        gradcheck(fn, (a, b), check_batched_grad=False)
-
-                    # Was the custom backward called properly
-                    bw_called[0] = 0
-                    ga_nz[0] = True  # For the case where the backward is called
-
-                    if inplace and output_is_a_view:
-                        with self.assertRaisesRegex(RuntimeError, err_msg):
-                            fn(a, b)
-                    else:
-                        fn(a, b).abs().backward()
-
-                    expected_called = 1
-                    expected_ga_nz = True
-
-                    if output_is_a_view and inplace:
-                        expected_called = 0
-
-                    self.assertTrue(bw_called[0] == expected_called)
-                    self.assertTrue(ga_nz[0] == expected_ga_nz)
+                        def fn(a, b):
+                            # never modify a, b inplace for gracheck
+                            a = a.clone()
+                            b = b.clone()
+                            if fn_id == "two_output":
+                                tmp1, tmp2 = IdTwoOutput.apply(
+                                    a, b, make_view, pure_view
+                                )
+                                if inplace:
+                                    tmp1 += 3
+                                    tmp2 += 3
+                                else:
+                                    tmp1 = tmp1 + 3
+                                    tmp2 = tmp2 + 3
+                                tmp = tmp1 * tmp2
+                            else:
+                                if fn_id == "one_output":
+                                    tmp = IdOneOutput.apply(a, make_view, pure_view)
+                                else:
+                                    tmp = ViewOfTemp.apply(a + b, make_view, pure_view)
+                                if inplace:
+                                    tmp += 3
+                                else:
+                                    tmp = tmp + 3
+
+                            return tmp.sum()
+
+                        a = torch.ones(2, dtype=dtype, requires_grad=True)
+                        b = torch.ones(2, dtype=dtype, requires_grad=True)
+
+                        err_msg = fn_id_to_inplace_on_view_err_msg[fn_id][
+                            int(pure_view)
+                        ]
+
+                        will_raise_error = (
+                            (pure_view and fn_id == "two_output")
+                            or (pure_view and fn_id == "view_of_temp" and inplace)
+                            or (not pure_view and inplace and output_is_a_view)
+                        )
+
+                        if will_raise_error:
+                            with self.assertRaisesRegex(RuntimeError, err_msg):
+                                gradcheck(fn, (a, b), check_batched_grad=False)
+                        else:
+                            gradcheck(fn, (a, b), check_batched_grad=False)
+
+                        # Was the custom backward called properly
+                        bw_called[0] = 0
+                        ga_nz[0] = True  # For the case where the backward is called
+
+                        expected_called = 1
+                        expected_ga_nz = True
+
+                        if will_raise_error:
+                            expected_called = 0
+                            with self.assertRaisesRegex(RuntimeError, err_msg):
+                                fn(a, b)
+                        else:
+                            fn(a, b).abs().backward()
+
+                        if (
+                            fn_id == "one_output"
+                            and inplace
+                            and output_is_a_view
+                            and pure_view
+                        ):
+                            # We expect the op to have been replayed and we leveraged the pure view
+                            # to re-create the graph, so the original backward was not called
+                            expected_called = 0
+
+                        self.assertTrue(bw_called[0] == expected_called)
+                        self.assertTrue(ga_nz[0] == expected_ga_nz)

     def test_autograd_simple_views_python(self):
         self._do_test_autograd_simple_views_python(torch.double)
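The special case in the new test (fn_id == "one_output" with inplace, a view output, and pure_view) is the behavioral payoff: instead of raising, autograd replays the view and skips the custom backward. A hypothetical illustration, reusing the `FirstTwo` sketch from above:

a = torch.ones(4, requires_grad=True)
base = a * 1                 # non-leaf base, so in-place on its view is legal
out = FirstTwo.apply(base)   # pure-view output
out += 3                     # without _is_pure_view this would hit the
                             # "created inside a custom Function" error above
out.sum().backward()
print(a.grad)                # tensor([1., 1., 0., 0.])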
torch/csrc/autograd/custom_function.cpp

@@ -261,7 +261,8 @@ static optional_variable_list _process_backward_mode_ad(
     const at::ArrayRef<std::optional<Variable>> raw_outputs,
     const std::shared_ptr<Node>& cdata,
     const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
-    const _view_as_self_fn_t& view_as_self_fn) {
+    const _view_as_self_fn_t& view_as_self_fn,
+    bool pure_view) {
   auto num_outputs = raw_outputs.size();

 #ifndef STRIP_ERROR_MESSAGES
@@ -404,7 +405,8 @@ static optional_variable_list _process_backward_mode_ad(
     if (!(is_input && is_modified) && var.is_view()) {
       // is_view() => diff_view_meta
       auto diff_view_meta = impl::get_view_autograd_meta(var);
-      diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION);
+      diff_view_meta->set_creation_meta(
+          pure_view ? CreationMeta::DEFAULT : CreationMeta::IN_CUSTOM_FUNCTION);
     }

     if (is_differentiable) {
@@ -448,13 +450,20 @@ optional_variable_list _wrap_outputs(
     const std::shared_ptr<Node>& cdata,
     const _jvp_fn_t& jvp_user_function,
     const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
-    const _view_as_self_fn_t& view_as_self_fn) {
+    const _view_as_self_fn_t& view_as_self_fn,
+    bool pure_view) {
   std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
   inputs_mapping.reserve(input_vars.size());
   for (const auto i : c10::irange(input_vars.size())) {
     inputs_mapping.emplace(input_vars[i].unsafeGetTensorImpl(), i);
   }

+  // Limit pure views to 1-1 mapping as it is unclear if it is even
+  // possible to have a pure view for N-1 or 1-N.
+  TORCH_CHECK(
+      !pure_view || (input_vars.size() == 1 && raw_outputs.size() == 1),
+      "Pure view custom Function can only have one input Tensor and one output Tensor. Open an issue if you need to support more.");
+
   auto outputs = _process_backward_mode_ad(
       inputs_mapping,
       non_differentiable,
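This 1-1 restriction is user-visible from Python: a Function that sets the flag but returns more than one output (or takes more than one Tensor input) is rejected when its outputs are wrapped. A hypothetical sketch of a Function that violates the rule on purpose:

import torch
from torch.autograd import Function

class SplitInTwo(Function):  # hypothetical; two outputs break the 1-1 rule
    @staticmethod
    def forward(ctx, x):
        ctx._is_pure_view = True
        return x.narrow(0, 0, 1), x.narrow(0, 1, 1)

    @staticmethod
    def backward(ctx, g1, g2):
        return torch.cat([g1, g2])

x = torch.ones(2, requires_grad=True)
# RuntimeError: Pure view custom Function can only have one input Tensor
# and one output Tensor. Open an issue if you need to support more.
SplitInTwo.apply(x)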
@@ -462,7 +471,8 @@ optional_variable_list _wrap_outputs(
       raw_outputs,
       cdata,
       to_save_if_setup_context,
-      view_as_self_fn);
+      view_as_self_fn,
+      pure_view);

   // This must happen after the backward processing as we expect the
   // computations happening here to track backward mode gradients.
torch/csrc/autograd/custom_function.h

@@ -24,7 +24,8 @@ TORCH_API std::vector<std::optional<Variable>> _wrap_outputs(
     const std::shared_ptr<Node>& cdata,
     const _jvp_fn_t& jvp_user_function,
     const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
-    const _view_as_self_fn_t& view_as_self_fn);
+    const _view_as_self_fn_t& view_as_self_fn,
+    bool pure_view);

 TORCH_API void check_variable_result(
     const at::TensorBase& original,
@@ -523,7 +524,8 @@ auto Function<T>::apply(Args&&... args)
       is_executable ? node : nullptr,
       jvp_fn,
       {},
-      view_as_self_fn);
+      view_as_self_fn,
+      false);

   node->output_info_.reserve(wrapped_outputs.size());
   for (auto& output : wrapped_outputs) {
torch/csrc/autograd/python_function.cpp

@@ -539,6 +539,7 @@ static PyObject* THPFunction_new(
   new (&self->saved_variables) std::vector<SavedVariable>();
   new (&self->is_variable_input) std::vector<bool>();
   self->materialize_grads = true;
+  self->pure_view = false;
   self->materialize_non_diff_grads = true;
   return obj;
 }
@@ -716,7 +717,8 @@ static void _wrap_outputs(
       cdata_if_executable,
       jvp_user_function,
       to_save_if_setup_context,
-      view_as_self_fn);
+      view_as_self_fn,
+      self->pure_view);

   for (const auto i : c10::irange(num_outputs)) {
     PyObject* obj = PyTuple_GetItem(raw_output, i);
@@ -1456,6 +1458,20 @@ int THPFunction_set_materialize_grads(
   END_HANDLE_TH_ERRORS_RET(-1)
 }

+int THPFunction_set_pure_view(
+    THPFunction* self,
+    PyObject* value,
+    void* unused) {
+  HANDLE_TH_ERRORS
+  if (!PyBool_Check(value)) {
+    THPUtils_invalidArguments(value, nullptr, "set_pure_view", 1, "(bool)");
+    return -1;
+  }
+  self->pure_view = (value == Py_True);
+  return 0;
+  END_HANDLE_TH_ERRORS_RET(-1)
+}
+
 PyObject* THPFunction_get_materialize_non_diff_grads(
     THPFunction* self,
     void* _unused) {
@@ -1730,6 +1746,11 @@ static struct PyGetSetDef THPFunction_properties[] = {
      (setter)THPFunction_set_materialize_grads,
      nullptr,
      nullptr},
+    {"_is_pure_view",
+     nullptr,
+     (setter)THPFunction_set_pure_view,
+     nullptr,
+     nullptr},
     {"_materialize_non_diff_grads",
      (getter)THPFunction_get_materialize_non_diff_grads,
      (setter)THPFunction_set_materialize_non_diff_grads,
torch/csrc/autograd/python_function.h

@@ -109,6 +109,10 @@ struct THPFunction {
   // Default is true.
   bool materialize_grads;

+  // boolean indicating whether the function is a "pure view", meaning that
+  // replaying the view is enough to get a correct backward.
+  bool pure_view;
+
   // boolean indicating whether to materialize output grad tensors
   // corresponding to non-differentiable outputs. Normally, someone would
   // already get this behavior by switching off materialize_grads,