mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Expose to python the backward AD view_func (#89586)
This will be useful for other systems (AOTAutograd) that want to replay autograd views. FYI @bdhirsh

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89586
Approved by: https://github.com/soulitzer
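For orientation, here is a minimal usage sketch (not part of the commit) of the newly exposed Tensor._view_func, modeled directly on the test added below; the shapes and ops are illustrative only:

import torch

# Build a view, then replay the same view operation onto a cloned base.
base = torch.rand(2, 2)
view = torch.select(base, 0, 0)        # `view` is a view of `base`

new_base = base.clone()                # same size/stride/storage_offset as `base`
replayed = view._view_func(new_base)   # recreate the same view on top of new_base

# The replayed view matches the original view's metadata.
assert replayed.size() == view.size()
assert replayed.stride() == view.stride()
assert replayed.storage_offset() == view.storage_offset()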
@@ -7236,6 +7236,28 @@ get_out().sum().backward()
             err_msg = "RuntimeError: one of the variables needed for gradient computation"
             self.assertTrue(err_msg in e.output.decode("utf-8"))

+    def test_view_func_replay(self):
+        def _assert_match_metadata(a, b):
+            self.assertEqual(a.size(), b.size())
+            self.assertEqual(a.stride(), b.stride())
+            self.assertEqual(a.storage_offset(), b.storage_offset())
+
+        def _test_op(fn, inp, args):
+            out = fn(inp, *args)
+            self.assertTrue(out._is_view)
+            self.assertTrue(out._base is inp)
+
+            new_inp = inp.clone()
+            _assert_match_metadata(new_inp, inp)
+            new_out = out._view_func(new_inp)
+            _assert_match_metadata(new_out, out)
+
+        _test_op(torch.select, torch.rand(2, 2), (0, 0))
+        _test_op(torch.as_strided, torch.rand(2, 2), ((4,), (1,)))
+        _test_op(torch.view_as_complex, torch.rand(2, 2), ())
+        _test_op(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat), ())
+
+
 def index_perm_variable(shape, max_indices):
     if not isinstance(shape, tuple):
         shape = (shape,)
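The binding added further down also requires the replacement base to have the same metadata as the view's original base. A hedged sketch of the expected failure mode (error type and message text inferred from the TORCH_CHECK in this diff):

import torch

base = torch.rand(2, 2)
view = torch.select(base, 0, 0)

# A base whose size/stride/storage_offset differ from the original base is
# rejected; TORCH_CHECK failures surface in Python as RuntimeError.
try:
    view._view_func(torch.rand(3, 3))  # wrong shape for the original base
except RuntimeError as e:
    assert "must have the same metadata" in str(e)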
@@ -82,7 +82,7 @@ using at::Tensor;
 // base if needed. Case 5 is handled in fw_grad by reading the forward grad from
 // the base if needed.

-namespace {
+namespace utils {

 // Enforcing that the metadata between the primal and tangent are same has two
 // goals:
@@ -139,7 +139,8 @@ bool has_same_meta(const Variable& base, const Variable& other) {
   }
   return true;
 }
-} // anonymous namespace
+
+} // namespace utils

 // This function is will ensure that the fw_grad_ is properly a view of the base
 // for inplace ops on Tensors that do not have forward grad originally.
@@ -219,7 +220,8 @@ void AutogradMeta::set_fw_grad(
       // Enforce same meta here to make sure that the view op below is
       // always valid
       Tensor new_base_fw_grad;
-      if (has_same_meta(new_grad, base) && has_same_meta(new_grad, self)) {
+      if (utils::has_same_meta(new_grad, base) &&
+          utils::has_same_meta(new_grad, self)) {
         // TODO extend this special case to when the underlying storage of
         // new_grad can be re-used.
         new_base_fw_grad = new_grad;
@@ -248,7 +250,7 @@ void AutogradMeta::set_fw_grad(
     }

     // Enforce the basic layout constraint
-    if (!has_same_meta(new_grad, self)) {
+    if (!utils::has_same_meta(new_grad, self)) {
       if (is_view_) {
         auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
         TORCH_INTERNAL_ASSERT(
@@ -684,6 +684,36 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
   Py_RETURN_NONE;
 }

+static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  const auto& self = THPVariable_Unpack(self_);
+  TORCH_CHECK(
+      THPVariable_Check(arg),
+      "_view_func expect a single argument that is a Tensor");
+  const auto& new_base = THPVariable_Unpack(arg);
+
+  // Ensure that self is indeed a backward differentiable view
+  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
+  TORCH_CHECK(
+      diff_view_meta && diff_view_meta->has_bw_view(),
+      "_view_func can only be called on "
+      "a Tensor that is a backward differentiable view.");
+  const auto& view_info = diff_view_meta->get_backward_view();
+  // Ensure that the newly provided base is similar to the original base
+  TORCH_CHECK(
+      torch::autograd::utils::has_same_meta(new_base, view_info.base_),
+      "The new base passed to _view_func must have the same metadata as the Tensors's base");
+
+  // Do the actual view replay
+  if (view_info.has_view_fn()) {
+    return THPVariable_Wrap(view_info.view_fn()(new_base));
+  } else {
+    return THPVariable_Wrap(new_base.as_strided(
+        self.sizes(), self.strides(), self.storage_offset()));
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 // Instantiates a subclass of self with the same data.
 static PyObject* THPVariable_as_subclass(
     PyObject* _self,
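When no explicit view_fn was recorded for the view, the binding above falls back to as_strided replay. In Python terms, that else-branch is roughly equivalent to the sketch below (the helper name is made up for illustration):

import torch

def replay_without_view_fn(view, new_base):
    # Mirrors the fallback branch above: rebuild the view purely from the
    # existing view's size/stride/storage offset on top of the new base.
    return new_base.as_strided(view.size(), view.stride(), view.storage_offset())

base = torch.rand(2, 2)
view = torch.as_strided(base, (4,), (1,))
replayed = replay_without_view_fn(view, base.clone())
assert replayed.size() == view.size() and replayed.stride() == view.stride()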
@@ -1645,6 +1675,7 @@ static PyMethodDef extra_methods[] = {
      METH_STATIC | METH_VARARGS | METH_KEYWORDS,
      nullptr},
     {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
+    {"_view_func", THPVariable_view_func, METH_O, nullptr},
     {nullptr}};

 /* From https://github.com/python/cpython/blob/v3.7.0/Modules/xxsubtype.c
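Since _view_func is registered with METH_O, it takes exactly one positional argument from Python, and the binding itself verifies that the argument is a Tensor. A small sketch of the calling convention (the error text is taken from the TORCH_CHECK above):

import torch

base = torch.rand(2, 2)
view = torch.select(base, 0, 0)

view._view_func(base.clone())  # OK: exactly one Tensor argument

try:
    view._view_func("not a tensor")  # rejected by the THPVariable_Check above
except RuntimeError as e:
    assert "expect a single argument that is a Tensor" in str(e)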
@@ -791,6 +791,11 @@ inline Variable make_variable(
   return Variable();
 }

+namespace utils {
+
+TORCH_API bool has_same_meta(const Variable& base, const Variable& other);
+
+} // namespace utils
 } // namespace autograd
 } // namespace torch

@@ -180,12 +180,12 @@ tensor_list2d broadcast_coalesced(

   unique_type_checker type_checker;
   at::cuda::CUDAGuard device_guard(devices[0]);
-  for (auto& chunk : utils::take_tensors(tensors, buffer_size)) {
+  for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) {
     auto type_id = chunk.type_id();
     type_checker.show(type_id);
     std::vector<at::Tensor> results;
     if (chunk.options().is_sparse()) {
-      auto flat_tuple = utils::flatten_sparse_tensors(chunk.tensors);
+      auto flat_tuple = torch::utils::flatten_sparse_tensors(chunk.tensors);
       auto broadcast_indices = broadcast(flat_tuple.first, devices);
       auto broadcast_values = broadcast(flat_tuple.second, devices);
       results.reserve(devices.size());
@@ -194,20 +194,20 @@ tensor_list2d broadcast_coalesced(
         auto& device_outputs = outputs[i];
         auto& inds = broadcast_indices[i];
         auto& vals = broadcast_values[i];
-        for (const auto& var :
-             utils::unflatten_sparse_tensors(inds, vals, chunk.tensors)) {
+        for (const auto& var : torch::utils::unflatten_sparse_tensors(
+                 inds, vals, chunk.tensors)) {
           // See NOTE [ Version Counter in comm.*_coalesced ]
           device_outputs.push_back(make_variable(var.tensor_data(), false));
         }
       }
     } else {
-      auto results =
-          broadcast(utils::flatten_dense_tensors(chunk.tensors), devices);
+      auto results = broadcast(
+          torch::utils::flatten_dense_tensors(chunk.tensors), devices);
       for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
         device_guard.set_index(devices[i]);
         auto& device_outputs = outputs[i];
         for (auto& var :
-             utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
+             torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
           // See NOTE [ Version Counter in comm.*_coalesced ]
           device_outputs.push_back(make_variable(var.tensor_data(), false));
         }
@@ -218,7 +218,7 @@ tensor_list2d broadcast_coalesced(
   // If we only saw a single tensor type, then we can skip expensive reordering
   if (!type_checker.unique) {
     for (auto& o : outputs)
-      utils::reorder_tensors_like(o, tensors);
+      torch::utils::reorder_tensors_like(o, tensors);
   }
   return outputs;
 }
@@ -276,6 +276,7 @@ def get_ignored_functions() -> Set[Callable]:
         Tensor._typed_storage,
         Tensor._reduce_ex_internal,
         Tensor._fix_weakref,
+        Tensor._view_func,
         Tensor._make_wrapper_subclass,
         Tensor._python_dispatch.__get__,
         Tensor._has_symbolic_sizes_strides.__get__,
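Adding Tensor._view_func to get_ignored_functions() keeps the new private method out of the __torch_function__ override machinery. Assuming a build that contains this commit, the exemption can be checked directly:

import torch
import torch.overrides

# Tensor._view_func is exempt from __torch_function__ override coverage.
assert torch.Tensor._view_func in torch.overrides.get_ignored_functions()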