Refactor PyInterpreter to use normal vtables (#84388)

I realized that we can deal with the dead vtable problem by...
introducing another indirection!  The resulting code is worse
(you have to do one more dereference to get to the vtable), but
the reduction in boilerplate is, IMO, worth it.

I did this refactor because I'm about to add a lot more methods
to PyInterpreter to handle expunging SymInt from TensorImpl.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84388
Approved by: https://github.com/albanD
This commit is contained in:
Edward Z. Yang
2022-09-01 13:43:06 -07:00
committed by PyTorch MergeBot
parent 241c99232e
commit f6ce2a442e
14 changed files with 234 additions and 474 deletions

View File

@ -188,12 +188,82 @@ void pushPyOutToStack(
namespace {
std::string concrete_name_fn(const c10::impl::PyInterpreter* self) {
std::stringstream ss;
ss << self;
return ss.str();
template <const char* func_name, typename... Ts>
void concrete_trace_cuda(Ts... args) {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
if (Py_IsInitialized()) {
try {
py::module mod = py::module::import("torch.utils._cuda_trace");
py::object hook = mod.attr(func_name).attr("fire_callbacks");
hook(args...);
} catch (const std::exception& e) {
LOG(ERROR) << "CUDA trace hook execution failed: " << e.what();
}
}
}
static constexpr char trace_cuda_event_creation_fn_name[] =
"CUDAEventCreationCallbacks";
static constexpr char trace_cuda_event_deletion_fn_name[] =
"CUDAEventDeletionCallbacks";
static constexpr char trace_cuda_event_record_fn_name[] =
"CUDAEventRecordCallbacks";
static constexpr char trace_cuda_event_wait_fn_name[] =
"CUDAEventWaitCallbacks";
static constexpr char trace_cuda_memory_allocation_fn_name[] =
"CUDAMemoryAllocationCallbacks";
static constexpr char trace_cuda_memory_deallocation_fn_name[] =
"CUDAMemoryDeallocationCallbacks";
static constexpr char trace_cuda_stream_creation_fn_name[] =
"CUDAStreamCreationCallbacks";
struct ConcretePyInterpreterVTable final
: public c10::impl::PyInterpreterVTable {
std::string name() const override;
void decref(PyObject* pyobj, bool is_tensor) const override;
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override;
void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
const override;
bool is_contiguous(const TensorImpl* self) const override;
c10::Device device(const TensorImpl* self) const override;
int64_t dim(const TensorImpl* self) const override;
c10::IntArrayRef strides(const TensorImpl* self) const override;
c10::IntArrayRef sizes(const TensorImpl* self) const override;
c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const override;
c10::Layout layout(const TensorImpl* self) const override;
c10::SymInt sym_numel(const TensorImpl* self) const override;
c10::SymIntArrayRef sym_strides(const TensorImpl* self) const override;
void trace_gpu_event_creation(uintptr_t event) const override {
concrete_trace_cuda<trace_cuda_event_creation_fn_name>(event);
}
void trace_gpu_event_deletion(uintptr_t event) const override {
concrete_trace_cuda<trace_cuda_event_deletion_fn_name>(event);
}
void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
const override {
concrete_trace_cuda<trace_cuda_event_record_fn_name>(event, stream);
}
void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {
concrete_trace_cuda<trace_cuda_event_wait_fn_name>(event, stream);
}
void trace_gpu_memory_allocation(uintptr_t ptr) const override {
concrete_trace_cuda<trace_cuda_memory_allocation_fn_name>(ptr);
}
void trace_gpu_memory_deallocation(uintptr_t ptr) const override {
concrete_trace_cuda<trace_cuda_memory_deallocation_fn_name>(ptr);
}
void trace_gpu_stream_creation(uintptr_t stream) const override {
concrete_trace_cuda<trace_cuda_stream_creation_fn_name>(stream);
}
};
// NOTE [PyInterpreter::decref takes an `is_tensor` arg]
// Before calling PyInterpreter::decref, we must statically know if the
// pyobj is a Tensor or not.
@ -202,10 +272,8 @@ std::string concrete_name_fn(const c10::impl::PyInterpreter* self) {
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
void concrete_decref_fn(
const c10::impl::PyInterpreter* self,
PyObject* pyobj,
bool is_tensor) {
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
const {
// Leak the pyobj if not initialized. This can happen if we are running
// exit handlers that are destructing tensors with residual (owned)
// PyObjects stored in them.
@ -235,82 +303,11 @@ void concrete_decref_fn(
Py_DECREF(pyobj);
};
c10::intrusive_ptr<TensorImpl> concrete_detach_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
void concrete_dispatch_fn(
const c10::impl::PyInterpreter*,
const c10::OperatorHandle& op,
torch::jit::Stack* stack);
bool concrete_is_contiguous_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
c10::Device concrete_device_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
int64_t concrete_dim_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
c10::IntArrayRef concrete_strides_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
c10::IntArrayRef concrete_sizes_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
c10::SymIntArrayRef concrete_sym_sizes_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
c10::Layout concrete_layout_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
c10::SymIntArrayRef concrete_sym_strides_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
c10::SymInt concrete_sym_numel_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self);
template <const char*, typename... Ts>
void concrete_trace_cuda(const c10::impl::PyInterpreter*, Ts...);
static constexpr char trace_cuda_event_creation_fn_name[] =
"CUDAEventCreationCallbacks";
static constexpr char trace_cuda_event_deletion_fn_name[] =
"CUDAEventDeletionCallbacks";
static constexpr char trace_cuda_event_record_fn_name[] =
"CUDAEventRecordCallbacks";
static constexpr char trace_cuda_event_wait_fn_name[] =
"CUDAEventWaitCallbacks";
static constexpr char trace_cuda_memory_allocation_fn_name[] =
"CUDAMemoryAllocationCallbacks";
static constexpr char trace_cuda_memory_deallocation_fn_name[] =
"CUDAMemoryDeallocationCallbacks";
static constexpr char trace_cuda_stream_creation_fn_name[] =
"CUDAStreamCreationCallbacks";
class PyInterpreterHolder {
public:
PyInterpreterHolder()
: impl_(new c10::impl::PyInterpreter(
&concrete_name_fn,
&concrete_decref_fn,
&concrete_detach_fn,
&concrete_dispatch_fn,
&concrete_is_contiguous_fn,
&concrete_device_fn,
&concrete_dim_fn,
&concrete_strides_fn,
&concrete_sizes_fn,
&concrete_sym_sizes_fn,
&concrete_layout_fn,
&concrete_sym_numel_fn,
&concrete_sym_strides_fn,
c10::impl::GPUTraceFunctionWrapper(
&concrete_trace_cuda<trace_cuda_event_creation_fn_name>,
&concrete_trace_cuda<trace_cuda_event_deletion_fn_name>,
&concrete_trace_cuda<trace_cuda_event_record_fn_name>,
&concrete_trace_cuda<trace_cuda_event_wait_fn_name>,
&concrete_trace_cuda<trace_cuda_memory_allocation_fn_name>,
&concrete_trace_cuda<trace_cuda_memory_deallocation_fn_name>,
&concrete_trace_cuda<trace_cuda_stream_creation_fn_name>))) {}
: impl_(new c10::impl::PyInterpreter(new ConcretePyInterpreterVTable())) {
}
// NB: intentionally leaks the memory
~PyInterpreterHolder() {
impl_->disarm();
@ -346,6 +343,12 @@ c10::impl::PyInterpreter* getPyInterpreter() {
return self_interpreter.get();
}
std::string ConcretePyInterpreterVTable::name() const {
std::stringstream ss;
ss << getPyInterpreter();
return ss.str();
}
PyObject* THPVariableClass = nullptr;
PyObject* ParameterClass = nullptr;
@ -2168,10 +2171,9 @@ py::object torchDispatchFromTensorImpl(
TorchFunctionName::TorchDispatch));
}
void concrete_dispatch_fn(
const c10::impl::PyInterpreter*,
void ConcretePyInterpreterVTable::dispatch(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
torch::jit::Stack* stack) const {
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
auto arguments = torch::jit::pop(*stack, num_arguments);
@ -2245,9 +2247,8 @@ void concrete_dispatch_fn(
op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
}
c10::intrusive_ptr<TensorImpl> concrete_detach_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
c10::intrusive_ptr<TensorImpl> ConcretePyInterpreterVTable::detach(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
@ -2271,9 +2272,8 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(
return res_t.getIntrusivePtr();
}
bool concrete_is_contiguous_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
bool ConcretePyInterpreterVTable::is_contiguous(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
@ -2297,9 +2297,7 @@ bool concrete_is_contiguous_fn(
return PyObject_IsTrue(out.ptr());
}
int64_t concrete_dim_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
@ -2323,9 +2321,8 @@ int64_t concrete_dim_fn(
return THPUtils_unpackLong(out.ptr());
}
c10::Device concrete_device_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
c10::Device ConcretePyInterpreterVTable::device(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
@ -2343,9 +2340,8 @@ c10::Device concrete_device_fn(
return toDevice(out.ptr());
}
c10::IntArrayRef concrete_strides_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
c10::IntArrayRef ConcretePyInterpreterVTable::strides(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
@ -2403,9 +2399,8 @@ static std::vector<int64_t> values_from_buffer(
return result;
}
c10::IntArrayRef concrete_sizes_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
@ -2435,9 +2430,8 @@ c10::IntArrayRef concrete_sizes_fn(
return c10::IntArrayRef(start, len);
}
c10::SymIntArrayRef concrete_sym_sizes_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
HANDLE_TH_ERRORS
@ -2476,9 +2470,8 @@ c10::SymIntArrayRef concrete_sym_sizes_fn(
END_HANDLE_TH_ERRORS_PYBIND
}
c10::Layout concrete_layout_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
c10::Layout ConcretePyInterpreterVTable::layout(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
auto out = torchDispatchFromTensorImpl(
@ -2501,9 +2494,8 @@ c10::Layout concrete_layout_fn(
return toLayout(out.ptr());
}
c10::SymInt concrete_sym_numel_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
c10::SymInt ConcretePyInterpreterVTable::sym_numel(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
auto out = torchDispatchFromTensorImpl(
@ -2528,25 +2520,8 @@ c10::SymInt concrete_sym_numel_fn(
: c10::SymInt{py::cast<int64_t>(out)};
}
template <const char* func_name, typename... Ts>
void concrete_trace_cuda(const c10::impl::PyInterpreter*, Ts... args) {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
if (Py_IsInitialized()) {
try {
py::module mod = py::module::import("torch.utils._cuda_trace");
py::object hook = mod.attr(func_name).attr("fire_callbacks");
hook(args...);
} catch (const std::exception& e) {
LOG(ERROR) << "CUDA trace hook execution failed: " << e.what();
}
}
}
c10::SymIntArrayRef concrete_sym_strides_fn(
const c10::impl::PyInterpreter*,
const c10::TensorImpl* self) {
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;
HANDLE_TH_ERRORS