Mirror of https://github.com/pytorch/pytorch.git
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63612

This makes Tensor inherit from a new class TensorBase, which provides a subset of Tensor that doesn't directly depend on native_functions.yaml. Code that only includes TensorBase.h will thus not need to be rebuilt every time someone changes an operator signature.

Making `Tensor` inherit from this class means that `const TensorBase&` parameters will be callable with an ordinary `Tensor`. I've also made `Tensor` constructible and assignable from `TensorBase` to minimize friction in code mixing the two types.

To help enforce that `Tensor.h` and `Functions.h` aren't accidentally included, I've added an error into `Operators.h` if `TORCH_ASSERT_NO_OPERATORS` is defined. We can either set this in the build system for certain folders, or just define it at the top of any file.

I've also included an example of manually special-casing the commonly used `contiguous` operator. The inline function's slow path defers to `TensorBase::__dispatch_contiguous`, which is defined in `Tensor.cpp`. I've made `OptionalTensorRef` constructible from `TensorBase`, so I can materialize a `Tensor` for use in dispatch without actually increasing its refcount.

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D30728580

Pulled By: ezyang

fbshipit-source-id: 2cbc8eee08043382ee6904ea8e743b1286921c03
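To illustrate the build-time isolation described above, here is a minimal sketch (not part of the header below; the `ATen/core/TensorBase.h` include path and the helper name are assumptions made for the example) of a translation unit that opts out of the generated operator headers:

  #define TORCH_ASSERT_NO_OPERATORS
  #include <ATen/core/TensorBase.h>

  // Since at::Tensor inherits from at::TensorBase, this helper is also callable
  // with an ordinary at::Tensor, yet the file does not need to be rebuilt when
  // an operator signature in native_functions.yaml changes.
  int64_t element_count(const at::TensorBase& t) {
    return t.numel();
  }

If the error added to `Operators.h` fires in such a file, the file is still pulling in `Tensor.h` or `Functions.h` somewhere.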
#pragma once

#include <torch/csrc/python_headers.h>
#include <memory>
#include <ATen/ATen.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/THP_export.h>
// Python object that backs torch.autograd.Variable
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPVariable {
  PyObject_HEAD;
  // Payload
  c10::MaybeOwned<at::Tensor> cdata;
  // Hooks to be run on backwards pass (corresponds to Python attr
  // '_backwards_hooks', set by 'register_hook')
  PyObject* backward_hooks = nullptr;
};
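// Example (sketch): code holding a THPVariable* typically reaches the payload
// through THPVariable_Unpack (declared below) rather than by dereferencing
// cdata directly:
//
//   THPVariable* self = ...;  // a Python torch.Tensor object
//   const at::Tensor& t = THPVariable_Unpack(self);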
THP_API PyObject *THPVariableClass;
THP_API PyObject *ParameterClass;

bool THPVariable_initModule(PyObject *module);
THP_API PyObject * THPVariable_Wrap(at::TensorBase var);
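// Example (sketch): since the parameter type is at::TensorBase, the wrapper is
// callable with an ordinary at::Tensor as well:
//
//   at::Tensor t = ...;
//   PyObject* obj = THPVariable_Wrap(t);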
static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
  // Check that a python object is a `Tensor`, but not a `Tensor` subclass.
  // (A subclass could have different semantics.) The one exception is
  // Parameter, which is used for Python bookkeeping but is equivalent to
  // Tensor as far as C++ is concerned.
  return (
      tp == (PyTypeObject*)THPVariableClass ||
      tp == (PyTypeObject*)ParameterClass
  );
}
static inline bool THPVariable_CheckExact(PyObject *obj) {
  return THPVariable_CheckTypeExact(Py_TYPE(obj));
}

inline bool THPVariable_Check(PyObject *obj)
{
  return THPVariableClass && PyObject_IsInstance(obj, THPVariableClass);
}
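// Example (sketch): THPVariable_Check accepts Tensor subclasses (it goes
// through PyObject_IsInstance), while THPVariable_CheckExact above accepts
// only Tensor itself or Parameter. A typical guard before unpacking:
//
//   if (THPVariable_Check(obj)) {
//     const at::Tensor& t = THPVariable_Unpack(obj);
//     // ... use t ...
//   }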
inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
  return *var->cdata;
}

inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
  return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj));
}
THP_API c10::impl::PyInterpreter* getPyInterpreter();