mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[Compiled Autograd] Move to torch::dynamo::autograd namespace (#105854)"
This reverts commit 26e3b4020f01d4fc2b7f63e1de4c94d2c8b362b5. Reverted https://github.com/pytorch/pytorch/pull/105854 on behalf of https://github.com/PaliC due to breaking internal embedded device tests (details shared with author) ([comment](https://github.com/pytorch/pytorch/pull/105854#issuecomment-1650559375))
This commit is contained in:
@ -40,14 +40,14 @@ namespace autograd {
|
||||
struct Edge;
|
||||
struct FunctionPostHook;
|
||||
struct FunctionPreHook;
|
||||
class CompiledNodeArgs;
|
||||
class SwapSavedVariables;
|
||||
|
||||
using tensor_list = std::vector<at::Tensor>;
|
||||
using variable_list = std::vector<Variable>;
|
||||
using edge_list = std::vector<Edge>;
|
||||
using saved_variable_list = std::vector<SavedVariable>;
|
||||
using IndexRange = std::pair<size_t, size_t>;
|
||||
using torch::dynamo::autograd::CompiledNodeArgs;
|
||||
using torch::dynamo::autograd::SwapSavedVariables;
|
||||
|
||||
// Custom deleter to prevent stack overflows.
|
||||
TORCH_API void deleteNode(Node* function);
|
||||
@ -578,7 +578,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
|
||||
// Used by compiled autograd to call apply() with different saved tensors
|
||||
// Implementations should call saved.before() on all attrs, then apply(), then
|
||||
// saved.after() on all attrs in the same order.
|
||||
// saved.after() on all attrs.
|
||||
virtual variable_list apply_with_saved(
|
||||
const variable_list& inputs,
|
||||
SwapSavedVariables& saved) {
|
||||
|
@ -5,16 +5,12 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::dynamo::autograd {
|
||||
class CompiledNodeArgs;
|
||||
class SwapSavedVariables;
|
||||
} // namespace torch::dynamo::autograd
|
||||
|
||||
// A hook that's called on gradients
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
|
||||
class CompiledNodeArgs;
|
||||
using Variable = at::Tensor;
|
||||
using variable_list = std::vector<Variable>;
|
||||
|
||||
@ -22,7 +18,7 @@ struct TORCH_API FunctionPreHook {
|
||||
virtual ~FunctionPreHook() = default;
|
||||
virtual variable_list operator()(const variable_list& grads) = 0;
|
||||
// only implemented for python hooks, registers hook with compiled autograd
|
||||
virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
|
||||
virtual void compiled_args(CompiledNodeArgs& args) {
|
||||
throw std::runtime_error(
|
||||
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
|
||||
typeid(*this).name());
|
||||
@ -35,7 +31,7 @@ struct TORCH_API FunctionPostHook {
|
||||
const variable_list& outputs /* grad_inputs */,
|
||||
const variable_list& inputs /* grad_outputs */) = 0;
|
||||
// only implemented for python hooks, registers hook with compiled autograd
|
||||
virtual void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) {
|
||||
virtual void compiled_args(CompiledNodeArgs& args) {
|
||||
throw std::runtime_error(
|
||||
std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
|
||||
typeid(*this).name());
|
||||
|
@ -236,8 +236,7 @@ void PyNode::compiled_args(CompiledNodeArgs& args) {
|
||||
}
|
||||
|
||||
// AotAutograd symints are all dynamic
|
||||
auto prior =
|
||||
args.set_default_dyn_type(torch::dynamo::autograd::SizeInput::DYNAMIC);
|
||||
auto prior = args.set_default_dyn_type(SizeInput::DYNAMIC);
|
||||
args.collect(f->compiled_autograd_symints);
|
||||
args.set_default_dyn_type(prior);
|
||||
|
||||
|
@ -11,7 +11,7 @@ struct PyFunctionTensorPreHook : public FunctionPreHook {
|
||||
PyFunctionTensorPreHook(PyObject* dict, int value_idx);
|
||||
~PyFunctionTensorPreHook() override;
|
||||
variable_list operator()(const variable_list& values) override;
|
||||
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
||||
void compiled_args(CompiledNodeArgs& args) override;
|
||||
PyObject* dict;
|
||||
int value_idx;
|
||||
};
|
||||
@ -20,7 +20,7 @@ struct PyFunctionPreHook : public FunctionPreHook {
|
||||
PyFunctionPreHook(PyObject* dict);
|
||||
~PyFunctionPreHook() override;
|
||||
variable_list operator()(const variable_list& values) override;
|
||||
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
||||
void compiled_args(CompiledNodeArgs& args) override;
|
||||
PyObject* dict;
|
||||
};
|
||||
|
||||
@ -30,7 +30,7 @@ struct PyFunctionPostHook : public FunctionPostHook {
|
||||
variable_list operator()(
|
||||
const variable_list& outputs,
|
||||
const variable_list& inputs) override;
|
||||
void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
|
||||
void compiled_args(CompiledNodeArgs& args) override;
|
||||
PyObject* dict;
|
||||
};
|
||||
|
||||
|
@ -10,8 +10,8 @@
|
||||
|
||||
// see [Note: Compiled Autograd]
|
||||
|
||||
namespace torch::dynamo::autograd {
|
||||
using namespace torch::autograd;
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
|
||||
struct SizeInput {
|
||||
// Note: int value is still needed when dynamic to pass as an arg
|
||||
@ -597,11 +597,12 @@ class SwapSavedVariables {
|
||||
std::shared_ptr<Node> node;
|
||||
};
|
||||
|
||||
} // namespace torch::dynamo::autograd
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
|
||||
template <>
|
||||
struct std::hash<torch::dynamo::autograd::CacheKey> {
|
||||
size_t operator()(const torch::dynamo::autograd::CacheKey& k) const {
|
||||
struct std::hash<torch::autograd::CacheKey> {
|
||||
size_t operator()(const torch::autograd::CacheKey& k) const {
|
||||
return k.hash();
|
||||
}
|
||||
};
|
||||
|
@ -10,7 +10,6 @@ static struct PyModuleDef _module =
|
||||
|
||||
namespace torch {
|
||||
namespace dynamo {
|
||||
using torch::dynamo::autograd::torch_c_dynamo_compiled_autograd_init;
|
||||
|
||||
void initDynamoBindings(PyObject* torch) {
|
||||
PyObject* dynamo = PyModule_Create(&_module);
|
||||
|
@ -45,7 +45,9 @@ Notes:
|
||||
- We require non-hook autograd nodes to be tracable.
|
||||
*/
|
||||
|
||||
namespace torch::dynamo::autograd {
|
||||
namespace torch {
|
||||
namespace dynamo {
|
||||
using namespace torch::autograd;
|
||||
using c10::SymInt;
|
||||
|
||||
static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
|
||||
@ -223,19 +225,15 @@ static PyObject* the_autograd_compiler = nullptr;
|
||||
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args);
|
||||
|
||||
static PyObject* clear_cache(PyObject* dummy, PyObject* args) {
|
||||
HANDLE_TH_ERRORS;
|
||||
CacheNode::root()->clear();
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS;
|
||||
}
|
||||
|
||||
static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) {
|
||||
HANDLE_TH_ERRORS;
|
||||
if (CacheNode::root()->is_empty()) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
Py_RETURN_FALSE;
|
||||
END_HANDLE_TH_ERRORS;
|
||||
}
|
||||
|
||||
static PyMethodDef _methods[] = {
|
||||
@ -472,7 +470,6 @@ variable_list compiled_autograd(
|
||||
}
|
||||
|
||||
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) {
|
||||
HANDLE_TH_ERRORS;
|
||||
PyObject* obj;
|
||||
if (!PyArg_ParseTuple(args, "O", &obj)) {
|
||||
return nullptr;
|
||||
@ -493,11 +490,11 @@ static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) {
|
||||
} else {
|
||||
return prior;
|
||||
}
|
||||
END_HANDLE_TH_ERRORS;
|
||||
}
|
||||
|
||||
PyObject* torch_c_dynamo_compiled_autograd_init() {
|
||||
return PyModule_Create(&_module);
|
||||
}
|
||||
|
||||
} // namespace torch::dynamo::autograd
|
||||
} // namespace dynamo
|
||||
} // namespace torch
|
||||
|
@ -2,6 +2,6 @@
|
||||
#include <torch/csrc/utils/python_stub.h>
|
||||
|
||||
// see [Note: Compiled Autograd]
|
||||
namespace torch::dynamo::autograd {
|
||||
namespace torch::dynamo {
|
||||
PyObject* torch_c_dynamo_compiled_autograd_init();
|
||||
} // namespace torch::dynamo::autograd
|
||||
} // namespace torch::dynamo
|
||||
|
Reference in New Issue
Block a user