mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Replay view with view_func instead of as_strided in meta_utils for NT (#112205)
Currently meta_utils relies on as_strided when handling the view case (recursively meta-ify the base, and then do as_strided to simulate the view), but NestedTensor does not support as_strided today (though maybe it could?), so what we want to do instead is call Tensor. _view_func. Conveniently, _view_func IS always available for nested tensors. A detail to note is that _view_func actually incurs a guard because it needs to perform some metadata checks to make sure the view is still valid. This PR adds Tensor._unsafe_view_func which can avoid that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112205 Approved by: https://github.com/jbschlosser
This commit is contained in:
committed by
PyTorch MergeBot
parent
503955f5ec
commit
0cda4c8abe
@ -525,7 +525,10 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
|
||||
static PyObject* view_func_impl(
|
||||
PyObject* self_,
|
||||
PyObject* arg,
|
||||
bool check_has_same_meta) {
|
||||
HANDLE_TH_ERRORS
|
||||
const auto& self = THPVariable_Unpack(self_);
|
||||
TORCH_CHECK(
|
||||
@ -540,7 +543,8 @@ static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
|
||||
if (diff_view_meta && diff_view_meta->has_bw_view()) {
|
||||
const auto& view_info = diff_view_meta->get_backward_view();
|
||||
// Ensure that the newly provided base is similar to the original base
|
||||
if (torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
|
||||
if (!check_has_same_meta ||
|
||||
torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
|
||||
// Do the actual view replay
|
||||
if (view_info.has_view_fn()) {
|
||||
out = view_info.view_fn()(new_base);
|
||||
@ -554,6 +558,14 @@ static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
|
||||
return view_func_impl(self_, arg, /*check_has_same_meta=*/true);
|
||||
}
|
||||
|
||||
static PyObject* THPVariable_view_func_unsafe(PyObject* self_, PyObject* arg) {
|
||||
return view_func_impl(self_, arg, /*check_has_same_meta=*/false);
|
||||
}
|
||||
|
||||
// Instantiates a subclass of self with the same data.
|
||||
static PyObject* THPVariable_as_subclass(
|
||||
PyObject* _self,
|
||||
@ -1637,6 +1649,7 @@ static PyMethodDef extra_methods[] = {
|
||||
nullptr},
|
||||
{"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
|
||||
{"_view_func", THPVariable_view_func, METH_O, nullptr},
|
||||
{"_view_func_unsafe", THPVariable_view_func_unsafe, METH_O, nullptr},
|
||||
{nullptr}};
|
||||
|
||||
struct THPVariableMeta {
|
||||
|
Reference in New Issue
Block a user