Files
pytorch/torch/csrc/autograd/variable.h
Adam Paszke 0325e2f646 Major autograd refactor
Improves autograd performance by more than 2x and fixes a couple
of bugs. All core functions have been moved to C.
2016-10-13 17:17:49 -07:00

53 lines
1.1 KiB
C

#ifndef THP_VARIABLE_H
#define THP_VARIABLE_H
// Reference-counted version counter that several owners can share.
//
// version_block points at a heap array of two ints:
//   version_block[0] - the version number (read with *v, bumped with v++)
//   version_block[1] - number of THPVariableVersion objects sharing the block
// The block is freed when the last sharer releases it. version_block is
// nullptr once this object has released its reference via cleanup().
struct THPVariableVersion {
  // Allocates a fresh block at version 0, owned only by this object.
  THPVariableVersion() {
    version_block = new int[2];
    version_block[0] = 0;
    version_block[1] = 1;
  }

  // Copying shares the source's block and takes a reference on it. Without
  // this, the implicit member-wise copy would duplicate the pointer without
  // bumping the count, and the block would be released once too often.
  THPVariableVersion(const THPVariableVersion &other)
      : version_block(other.version_block) {
    if (version_block)
      version_block[1]++;
  }

  // Assignment shares the source's block; see join_with for the ordering.
  THPVariableVersion &operator=(const THPVariableVersion &other) {
    int *shared = other.version_block;
    if (shared)
      shared[1]++;
    cleanup();
    version_block = shared;
    return *this;
  }

  // Post-increment: bumps the version and returns the value it had before.
  int operator++(int) { return version_block[0]++; }

  // Returns the current version number.
  int operator*() { return *version_block; }

  // Makes this object share `other`'s block (and therefore its version).
  // `other` must still hold a block, i.e. must not have been cleaned up.
  void join_with(THPVariableVersion &other) {
    int *shared = other.version_block;
    // Acquire the new reference before dropping the old one: if `other` is
    // this very object (or already shares our block), releasing first could
    // free the block we are about to join.
    shared[1]++;
    cleanup();
    version_block = shared;
  }

  // Drops this object's reference and frees the block if it was the last one.
  // The pointer is always cleared, so calling cleanup() again (for instance
  // from the destructor after an explicit call) is a no-op rather than a
  // second decrement.
  void cleanup() {
    if (!version_block)
      return;
    int *block = version_block;
    version_block = nullptr;
    if (--block[1] == 0)
      delete[] block;
  }

  ~THPVariableVersion() { cleanup(); }

  int *version_block;
};
// Object layout for the Variable type exposed through the Python C API.
// NOTE(review): this header does not show how the fields are used; the
// per-field notes below are inferred from names and types - confirm against
// the implementation file before relying on them.
struct THPVariable {
// CPython object-header macro; kept as the first member.
PyObject_HEAD
// presumably the function object that produced this variable - TODO confirm
PyObject *creator;
// presumably the wrapped tensor - TODO confirm
PyObject *data;
// presumably the accumulated gradient - TODO confirm
PyObject *grad;
// presumably a container of hooks to run during backward - TODO confirm
PyObject *backward_hooks;
// Pointer to this variable's version counter (a THPVariableVersion).
// NOTE(review): ownership is not visible from this header.
THPVariableVersion *version_counter;
// presumably this variable's index among its creator's outputs - TODO confirm
int output_nr;
// Stored as char; presumably used as boolean flags - TODO confirm.
char is_volatile;
char requires_grad;
};
// Module initialisation hook for the Variable type; returns a bool status.
// NOTE(review): presumably registers the type on `module` and returns false
// on failure - confirm against the definition.
bool THPVariable_initModule(PyObject *module);
// Class object that THPVariable_Check tests instances against. Declared here,
// defined in another translation unit.
extern PyObject *THPVariableClass;
// Constructors returning a PyObject*.
// NOTE(review): presumably these return a new reference (NULL on error) and
// NewVolatile builds a variable with is_volatile set - confirm.
PyObject * THPVariable_NewVolatile(PyObject *data);
PyObject * THPVariable_New(PyObject *data, PyObject *creator, char requires_grad);
// Expands to PyObject_IsInstance(obj, THPVariableClass).
// NOTE(review): PyObject_IsInstance returns -1 on error, which is truthy when
// this macro is used directly as a condition; callers that need to tell an
// error from a match should compare the result against 1.
#define THPVariable_Check(obj) PyObject_IsInstance(obj, THPVariableClass)
#endif