[ca] side-effect free initial trace: GraphTask (#147796)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147796
Approved by: https://github.com/jansel
ghstack dependencies: #147242
This commit is contained in:
Simon Fan
2025-02-25 19:57:55 -08:00
committed by PyTorch MergeBot
parent 0a2da008f8
commit 5e3069dde8
2 changed files with 10 additions and 7 deletions

View File

@@ -138,7 +138,7 @@ struct TORCH_API Engine {
 // see [Note: Compiled Autograd]
 typedef variable_list (*compiled_autograd_fn)(
 const std::shared_ptr<Node>& graph_root,
-GraphTask& graph_task,
+const GraphTask& graph_task,
 bool accumulate_grad,
 const edge_list& outputs);
 static void set_compiled_autograd(compiled_autograd_fn fn);

View File

@@ -800,7 +800,7 @@ static SizeInput::DynType get_default_dyn_type() {
 // Only call this function while holding GIL
 static CacheNode* _compiled_autograd_impl(
 const std::shared_ptr<Node>& graph_root,
-GraphTask& graph_task,
+const GraphTask& graph_task,
 bool accumulate_grad,
 const edge_list& output_edges,
 THPObjectPtr* graph_arg_inputs,
@@ -808,7 +808,10 @@ static CacheNode* _compiled_autograd_impl(
 THPObjectPtr* graph_arg_ivalue_args,
 THPObjectPtr* graph_arg_hooks,
 THPObjectPtr* graph_arg_packed_inputs) {
-std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
+const std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
+std::unordered_map<Node*, int> visited_dependencies;
+visited_dependencies.reserve(dependencies.size());
 std::vector<std::shared_ptr<Node>> worklist{graph_root};
 AutogradCompilerCall compiler_call(get_default_dyn_type());
@@ -872,9 +875,9 @@ static CacheNode* _compiled_autograd_impl(
 }
 }
 auto it = dependencies.find(edge.function.get());
-TORCH_INTERNAL_ASSERT(it != dependencies.end());
-if (--it->second == 0) {
-dependencies.erase(it);
+int count = ++visited_dependencies[it->first];
+TORCH_INTERNAL_ASSERT(count <= it->second);
+if (count == it->second) {
 worklist.emplace_back(edge.function);
 }
 }
@@ -1090,7 +1093,7 @@ struct LockGuardWithErrorLogs {
 static variable_list compiled_autograd(
 const std::shared_ptr<Node>& graph_root,
-GraphTask& graph_task,
+const GraphTask& graph_task,
 bool accumulate_grad,
 const edge_list& output_edges) {
 TORCH_CHECK(