mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add missing attr access check for legacy autograd.Function (#155055)
Fixes https://github.com/pytorch/pytorch/issues/154981 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155055 Approved by: https://github.com/albanD ghstack dependencies: #154509, #154852
This commit is contained in:
committed by
PyTorch MergeBot
parent
5dd07c70e5
commit
1ed243f01c
@ -3725,6 +3725,18 @@ class TestAutograd(TestCase):
|
|||||||
f.next_functions
|
f.next_functions
|
||||||
with self.assertRaisesRegex(RuntimeError, "Attribute 'name' is invalid"):
|
with self.assertRaisesRegex(RuntimeError, "Attribute 'name' is invalid"):
|
||||||
f.name()
|
f.name()
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, "Attribute '_sequence_nr' is invalid"
|
||||||
|
):
|
||||||
|
f._sequence_nr()
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, "Attribute '_set_sequence_nr' is invalid"
|
||||||
|
):
|
||||||
|
f._set_sequence_nr(2)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, "Attribute '_input_metadata' is invalid"
|
||||||
|
):
|
||||||
|
f._input_metadata
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError, "underlying PyNode has already been deallocated"
|
RuntimeError, "underlying PyNode has already been deallocated"
|
||||||
):
|
):
|
||||||
|
@ -60,6 +60,20 @@ PyObject* THPGradientEdgeClass = nullptr;
|
|||||||
// Anonymous namespace for helpful functions used in this file
|
// Anonymous namespace for helpful functions used in this file
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
inline void check_legacy_fn_attr_access(
|
||||||
|
const std::shared_ptr<torch::autograd::Node>& cdata,
|
||||||
|
const char* attr) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
cdata,
|
||||||
|
"Attribute '",
|
||||||
|
attr,
|
||||||
|
"' is invalid for this instance of _C._FunctionBase. "
|
||||||
|
"Accessing this attribute directly on an instance of autograd.Function "
|
||||||
|
"is a legacy access pattern that is no longer supported. For examples "
|
||||||
|
"on how to use new‑style autograd functions, see "
|
||||||
|
"https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: We shouldn't need to call this function because the engine
|
// TODO: We shouldn't need to call this function because the engine
|
||||||
// can already persist the errors for us. This still seems to be
|
// can already persist the errors for us. This still seems to be
|
||||||
// needed for the DistEngine however.
|
// needed for the DistEngine however.
|
||||||
@ -1142,13 +1156,7 @@ PyObject* process_outputs(
|
|||||||
PyObject* THPFunction_name(PyObject* self, PyObject* noargs) {
|
PyObject* THPFunction_name(PyObject* self, PyObject* noargs) {
|
||||||
HANDLE_TH_ERRORS
|
HANDLE_TH_ERRORS
|
||||||
auto cdata = ((THPFunction*)self)->cdata.lock();
|
auto cdata = ((THPFunction*)self)->cdata.lock();
|
||||||
TORCH_CHECK(
|
check_legacy_fn_attr_access(cdata, "name");
|
||||||
cdata,
|
|
||||||
"Attribute 'name' is invalid for this instance of _C._FunctionBase. "
|
|
||||||
"Accessing this attribute directly on an instance of autograd.Function is a legacy "
|
|
||||||
"access pattern that is no longer supported. For examples on how to use new-style "
|
|
||||||
"autograd functions, see "
|
|
||||||
"https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
|
|
||||||
return THPUtils_packString(cdata->name());
|
return THPUtils_packString(cdata->name());
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
}
|
}
|
||||||
@ -1156,6 +1164,7 @@ PyObject* THPFunction_name(PyObject* self, PyObject* noargs) {
|
|||||||
PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) {
|
PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) {
|
||||||
HANDLE_TH_ERRORS;
|
HANDLE_TH_ERRORS;
|
||||||
auto cdata = ((THPFunction*)self)->cdata.lock();
|
auto cdata = ((THPFunction*)self)->cdata.lock();
|
||||||
|
check_legacy_fn_attr_access(cdata, "_sequence_nr");
|
||||||
return THPUtils_packUInt64(cdata->sequence_nr());
|
return THPUtils_packUInt64(cdata->sequence_nr());
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
}
|
}
|
||||||
@ -1163,6 +1172,7 @@ PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) {
|
|||||||
PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
|
PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
|
||||||
HANDLE_TH_ERRORS;
|
HANDLE_TH_ERRORS;
|
||||||
auto cdata = ((THPFunction*)self)->cdata.lock();
|
auto cdata = ((THPFunction*)self)->cdata.lock();
|
||||||
|
check_legacy_fn_attr_access(cdata, "_set_sequence_nr");
|
||||||
cdata->set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
|
cdata->set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
@ -1171,6 +1181,7 @@ PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
|
|||||||
PyObject* THPFunction_input_metadata(PyObject* self, void* unused) {
|
PyObject* THPFunction_input_metadata(PyObject* self, void* unused) {
|
||||||
HANDLE_TH_ERRORS;
|
HANDLE_TH_ERRORS;
|
||||||
auto cdata = ((THPFunction*)self)->cdata.lock();
|
auto cdata = ((THPFunction*)self)->cdata.lock();
|
||||||
|
check_legacy_fn_attr_access(cdata, "_input_metadata");
|
||||||
const auto num_inputs = cdata->num_inputs();
|
const auto num_inputs = cdata->num_inputs();
|
||||||
THPObjectPtr list(PyTuple_New(num_inputs));
|
THPObjectPtr list(PyTuple_New(num_inputs));
|
||||||
if (!list) {
|
if (!list) {
|
||||||
@ -1388,13 +1399,7 @@ PyObject* THPFunction__register_hook_dict(PyObject* _self, PyObject* _var) {
|
|||||||
new PyFunctionTensorPreHook(var->backward_hooks, tensor.output_nr()));
|
new PyFunctionTensorPreHook(var->backward_hooks, tensor.output_nr()));
|
||||||
auto self = (THPFunction*)_self;
|
auto self = (THPFunction*)_self;
|
||||||
auto cdata = self->cdata.lock();
|
auto cdata = self->cdata.lock();
|
||||||
TORCH_CHECK(
|
check_legacy_fn_attr_access(cdata, "_register_hook_dict");
|
||||||
cdata,
|
|
||||||
"Attribute '_register_hook_dict' is invalid for this instance of _C._FunctionBase. "
|
|
||||||
"Accessing this attribute directly on an instance of autograd.Function is a legacy "
|
|
||||||
"access pattern that is no longer supported. For examples on how to use new-style "
|
|
||||||
"autograd functions, see "
|
|
||||||
"https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
|
|
||||||
cdata->add_tensor_pre_hook(std::move(hook));
|
cdata->add_tensor_pre_hook(std::move(hook));
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
@ -1404,13 +1409,7 @@ PyObject* THPFunction_register_hook(PyObject* _self, PyObject* hook) {
|
|||||||
HANDLE_TH_ERRORS
|
HANDLE_TH_ERRORS
|
||||||
auto self = (THPFunction*)_self;
|
auto self = (THPFunction*)_self;
|
||||||
auto cdata = self->cdata.lock();
|
auto cdata = self->cdata.lock();
|
||||||
TORCH_CHECK(
|
check_legacy_fn_attr_access(cdata, "register_hook");
|
||||||
cdata,
|
|
||||||
"Attribute 'register_hook' is invalid for this instance of _C._FunctionBase. "
|
|
||||||
"Accessing this attribute directly on an instance of autograd.Function is a legacy "
|
|
||||||
"access pattern that is no longer supported. For examples on how to use new-style "
|
|
||||||
"autograd functions, see "
|
|
||||||
"https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
|
|
||||||
return torch::autograd::registerFunctionHook(*cdata, hook);
|
return torch::autograd::registerFunctionHook(*cdata, hook);
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
}
|
}
|
||||||
@ -1419,13 +1418,7 @@ PyObject* THPFunction_register_prehook(PyObject* _self, PyObject* hook) {
|
|||||||
HANDLE_TH_ERRORS
|
HANDLE_TH_ERRORS
|
||||||
auto self = (THPFunction*)_self;
|
auto self = (THPFunction*)_self;
|
||||||
auto cdata = self->cdata.lock();
|
auto cdata = self->cdata.lock();
|
||||||
TORCH_CHECK(
|
check_legacy_fn_attr_access(cdata, "register_prehook");
|
||||||
cdata,
|
|
||||||
"Attribute 'register_prehook' is invalid for this instance of _C._FunctionBase. "
|
|
||||||
"Accessing this attribute directly on an instance of autograd.Function is a legacy "
|
|
||||||
"access pattern that is no longer supported. For examples on how to use new-style "
|
|
||||||
"autograd functions, see "
|
|
||||||
"https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
|
|
||||||
return torch::autograd::registerFunctionPreHook(*cdata, hook);
|
return torch::autograd::registerFunctionPreHook(*cdata, hook);
|
||||||
END_HANDLE_TH_ERRORS
|
END_HANDLE_TH_ERRORS
|
||||||
}
|
}
|
||||||
@ -1568,13 +1561,7 @@ PyObject* THPFunction_raw_saved_tensors(THPFunction* self, void* _unused) {
|
|||||||
PyObject* THPFunction_next_functions(THPFunction* self, void* _unused) {
|
PyObject* THPFunction_next_functions(THPFunction* self, void* _unused) {
|
||||||
HANDLE_TH_ERRORS
|
HANDLE_TH_ERRORS
|
||||||
auto cdata = self->cdata.lock();
|
auto cdata = self->cdata.lock();
|
||||||
TORCH_CHECK(
|
check_legacy_fn_attr_access(cdata, "next_functions");
|
||||||
cdata,
|
|
||||||
"Attribute 'next_functions' is invalid for this instance of _C._FunctionBase. "
|
|
||||||
"Accessing this attribute directly on an instance of autograd.Function is a legacy "
|
|
||||||
"access pattern that is no longer supported. For examples on how to use new-style "
|
|
||||||
"autograd functions, see "
|
|
||||||
"https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
|
|
||||||
const auto num_outputs = cdata->num_outputs();
|
const auto num_outputs = cdata->num_outputs();
|
||||||
THPObjectPtr result(PyTuple_New(num_outputs));
|
THPObjectPtr result(PyTuple_New(num_outputs));
|
||||||
if (!result)
|
if (!result)
|
||||||
|
Reference in New Issue
Block a user