Trace C functions

This commit is contained in:
Adam Paszke
2017-08-02 14:40:11 -07:00
committed by Soumith Chintala
parent bdcbbeaf68
commit 1c4538e017
18 changed files with 126 additions and 36 deletions

View File

@ -211,6 +211,10 @@ auto PyFunction::name() -> std::string {
return std::string(Py_TYPE(f)->tp_name);
}
auto PyFunction::getSharedPtr() -> std::shared_ptr<Function> {
return THPFunction_asFunction((THPFunction*)obj);
}
}} // namespace torch::autograd
// Traverse and clear are required for supporting Python's GC cycle handling.
@ -322,6 +326,8 @@ static void _mark_dirty(THPFunction *self, t2var_type &t2var,
"variables, but detected that there are %d objects sharing it",
v_counter.var_refcnt());
v_counter++;
// In-place modifications invalidate the trace
variable->cdata->tracing_state.reset();
}
// We're not going to ever need this so let's remove references now
Py_DECREF(self->dirty_tensors);