mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ca] side-effect free initial trace: GraphTask (#147796)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147796 Approved by: https://github.com/jansel ghstack dependencies: #147242
This commit is contained in:
committed by
PyTorch MergeBot
parent
0a2da008f8
commit
5e3069dde8
@ -138,7 +138,7 @@ struct TORCH_API Engine {
|
||||
// see [Note: Compiled Autograd]
|
||||
typedef variable_list (*compiled_autograd_fn)(
|
||||
const std::shared_ptr<Node>& graph_root,
|
||||
GraphTask& graph_task,
|
||||
const GraphTask& graph_task,
|
||||
bool accumulate_grad,
|
||||
const edge_list& outputs);
|
||||
static void set_compiled_autograd(compiled_autograd_fn fn);
|
||||
|
@ -800,7 +800,7 @@ static SizeInput::DynType get_default_dyn_type() {
|
||||
// Only call this function while holding GIL
|
||||
static CacheNode* _compiled_autograd_impl(
|
||||
const std::shared_ptr<Node>& graph_root,
|
||||
GraphTask& graph_task,
|
||||
const GraphTask& graph_task,
|
||||
bool accumulate_grad,
|
||||
const edge_list& output_edges,
|
||||
THPObjectPtr* graph_arg_inputs,
|
||||
@ -808,7 +808,10 @@ static CacheNode* _compiled_autograd_impl(
|
||||
THPObjectPtr* graph_arg_ivalue_args,
|
||||
THPObjectPtr* graph_arg_hooks,
|
||||
THPObjectPtr* graph_arg_packed_inputs) {
|
||||
std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
|
||||
const std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
|
||||
std::unordered_map<Node*, int> visited_dependencies;
|
||||
visited_dependencies.reserve(dependencies.size());
|
||||
|
||||
std::vector<std::shared_ptr<Node>> worklist{graph_root};
|
||||
AutogradCompilerCall compiler_call(get_default_dyn_type());
|
||||
|
||||
@ -872,9 +875,9 @@ static CacheNode* _compiled_autograd_impl(
|
||||
}
|
||||
}
|
||||
auto it = dependencies.find(edge.function.get());
|
||||
TORCH_INTERNAL_ASSERT(it != dependencies.end());
|
||||
if (--it->second == 0) {
|
||||
dependencies.erase(it);
|
||||
int count = ++visited_dependencies[it->first];
|
||||
TORCH_INTERNAL_ASSERT(count <= it->second);
|
||||
if (count == it->second) {
|
||||
worklist.emplace_back(edge.function);
|
||||
}
|
||||
}
|
||||
@ -1090,7 +1093,7 @@ struct LockGuardWithErrorLogs {
|
||||
|
||||
static variable_list compiled_autograd(
|
||||
const std::shared_ptr<Node>& graph_root,
|
||||
GraphTask& graph_task,
|
||||
const GraphTask& graph_task,
|
||||
bool accumulate_grad,
|
||||
const edge_list& output_edges) {
|
||||
TORCH_CHECK(
|
||||
|
Reference in New Issue
Block a user