Make Python op registration work with torchdeploy/multipy (#87162)

See the strategy comment at the top of PythonOpRegistrationTrampoline.cpp
for the big picture.

Along the way, I made OperatorHandle support == and hashing,
and slightly changed the low-level python_dispatch impl API
to disallow empty strings for the dispatch key. This had the knock-on
effect of requiring us to explicitly pass CompositeImplicitAutograd
wherever we previously would have passed "" (I didn't apply
this to the rest of the file because I'm lazy.)

The test strategy is to delete the logic that skipped Python op
registrations in torch when running in a torchdeploy context,
and to show that CI still works.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87162
Approved by: https://github.com/anjali411, https://github.com/bdhirsh
Author: Edward Z. Yang
Date: 2022-11-02 19:08:07 -07:00
Committed by: PyTorch MergeBot
Parent: 2f296cfdbb
Commit: f884e817d4
20 changed files with 527 additions and 106 deletions

View File

@ -0,0 +1,28 @@
#include <ATen/core/PythonOpRegistrationTrampoline.h>
namespace at {
namespace impl {
// The strategy is that all python interpreters attempt to register themselves
// as the main interpreter, but only one wins. Only that interpreter is
// allowed to interact with the C++ dispatcher. Furthermore, when we execute
// logic on that interpreter, we do so hermetically, never setting pyobj field
// on Tensor.
std::atomic<c10::impl::PyInterpreter*> PythonOpRegistrationTrampoline::interpreter_{nullptr};
bool PythonOpRegistrationTrampoline::registerInterpreter(c10::impl::PyInterpreter* interp) {
c10::impl::PyInterpreter* expected = nullptr;
interpreter_.compare_exchange_strong(expected, interp);
if (expected != nullptr) {
// This is the second (or later) Python interpreter, which means we need
// non-trivial hermetic PyObject TLS
c10::impl::HermeticPyObjectTLS::init_state();
return false;
} else {
return true;
}
}
} // namespace impl
} // namespace at
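To make the registration race concrete, here is a minimal sketch (not part of the diff) of how an embedding application might drive registerInterpreter; onInterpreterStartup and the interp pointer are hypothetical names:

// Hypothetical sketch: each Python interpreter calls this once at startup.
void onInterpreterStartup(c10::impl::PyInterpreter* interp) {
  bool is_main =
      at::impl::PythonOpRegistrationTrampoline::registerInterpreter(interp);
  if (is_main) {
    // Won the compare-exchange: this interpreter performs the real
    // C++ dispatcher registrations.
  } else {
    // Lost: hermetic PyObject TLS was initialized as a side effect, and
    // this interpreter only verifies (waits for) registrations.
  }
}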

View File

@ -0,0 +1,18 @@
#include <ATen/core/dispatch/Dispatcher.h>
// TODO: this can probably live in c10
namespace at {
namespace impl {
class TORCH_API PythonOpRegistrationTrampoline final {
static std::atomic<c10::impl::PyInterpreter*> interpreter_;
public:
// Returns true if you successfully registered yourself (that means
// you are in the hot seat for doing the operator registrations!)
static bool registerInterpreter(c10::impl::PyInterpreter*);
};
} // namespace impl
} // namespace at

View File

@ -1,6 +1,7 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <list>
#include <sstream>
#include <chrono>
namespace c10 {
@ -50,7 +51,9 @@ Dispatcher::Dispatcher()
, operatorLookupTable_()
, backendFallbackKernels_()
, listeners_(std::make_unique<detail::RegistrationListenerList>())
, mutex_() {}
, mutex_()
, cond_var_()
{}
Dispatcher::~Dispatcher() = default;
@ -69,6 +72,41 @@ c10::optional<OperatorHandle> Dispatcher::findOp(const OperatorName& overload_na
});
}
// NB: If you add more waitFor* implementations, you also have to add
// appropriate notify_all() calls to the relevant register calls
void Dispatcher::waitForDef(const FunctionSchema& schema) {
using namespace std::chrono_literals;
std::unique_lock<std::mutex> lock(mutex_);
bool r = cond_var_.wait_for(lock, 2s, [&]{
return findOp(schema.operator_name()) != c10::nullopt;
});
TORCH_INTERNAL_ASSERT(r,
"Expected main interpreter to define ", schema.operator_name(),
", but this didn't happen within timeout. Are you trying to load "
"different models in the same torchdeploy/multipy instance? You "
"must warmup each interpreter identically, e.g., import all "
"the same dependencies.");
}
void Dispatcher::waitForImpl(const OperatorName& op_name, c10::optional<c10::DispatchKey> maybe_dk) {
using namespace std::chrono_literals;
std::unique_lock<std::mutex> lock(mutex_);
auto dk = maybe_dk.value_or(DispatchKey::CompositeImplicitAutograd);
auto op = findOrRegisterName_(op_name);
bool r = cond_var_.wait_for(lock, 2s, [&]{
// NB: this is slightly unsound for overrides, but overrides are
// funny business anyway
return op.hasKernelForDispatchKey(dk);
});
TORCH_INTERNAL_ASSERT(r,
"Expected main interpreter to implement ", dk, " for ", op_name,
", but this didn't happen within timeout. Are you trying to load "
"different models in the same torchdeploy/multipy instance? You "
"must warmup each interpreter identically, e.g., import all "
"the same dependencies.");
}
c10::optional<OperatorHandle> Dispatcher::findSchema(const OperatorName& overload_name) {
auto it = findOp(overload_name);
if (it.has_value()) {
@ -175,6 +213,8 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin
++op.operatorDef_->def_count;
++op.operatorDef_->def_and_impl_count;
cond_var_.notify_all();
return RegistrationHandleRAII([this, op, op_name] {
deregisterDef_(op, op_name);
});
@ -227,6 +267,8 @@ RegistrationHandleRAII Dispatcher::registerImpl(
++op.operatorDef_->def_and_impl_count;
cond_var_.notify_all();
return RegistrationHandleRAII([this, op, op_name, dispatch_key, handle] {
deregisterImpl_(op, op_name, dispatch_key, handle);
});
@ -249,6 +291,7 @@ RegistrationHandleRAII Dispatcher::registerName(OperatorName op_name) {
std::lock_guard<std::mutex> lock(mutex_);
auto op = findOrRegisterName_(op_name);
++op.operatorDef_->def_and_impl_count;
return RegistrationHandleRAII(
[this, op, op_name] { deregisterName_(op, op_name); });
}
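To illustrate the handshake these two functions enable, here is a hedged sketch of the two sides (assuming schema is a FunctionSchema parsed elsewhere and that both interpreters execute the same registration code):

auto& d = c10::Dispatcher::singleton();
// Main interpreter: actually registers; the notify_all() inside
// registerDef/registerImpl wakes any waiters.
//   auto handle = d.registerDef(schema, "debug string", /*tags=*/{});
// Non-main interpreter: blocks (up to the 2s timeout) until the def and
// a CPU impl appear, and asserts with the warmup diagnostic if they don't.
d.waitForDef(schema);
d.waitForImpl(schema.operator_name(), c10::DispatchKey::CPU);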

View File

@ -11,6 +11,7 @@
#include <c10/util/LeftRight.h>
#include <list>
#include <mutex>
#include <condition_variable>
#include <type_traits>
#include <ATen/core/grad_mode.h>
@ -182,6 +183,9 @@ public:
return backendFallbackKernels_[dispatch_ix].kernel.isValid();
}
// Used by torchdeploy/multipy for multiple interpreters racing.
void waitForDef(const FunctionSchema& schema);
void waitForImpl(const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key);
// ------------------------------------------------------------------------
//
@ -307,7 +311,23 @@ private:
std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;
std::unique_ptr<detail::RegistrationListenerList> listeners_;
// This mutex protects concurrent access to the dispatcher
std::mutex mutex_;
// This condition variable gets notified whenever we add a new def/impl to the
// dispatch table. This is primarily used by multipy/torchdeploy, when
// we have multiple interpreters trying to register to the dispatch table.
// In this situation, whenever the non-primary interpreter would have tried
// to register to the dispatch table, instead it will check to see if the
// expected registration has already been made, and if it hasn't, wait on
// this condition variable to see if it was just racing with the primary
// interpreter.
//
// We expect it to be rare for there to be any waiters on this condition
// variable. This is mostly just to help give better diagnostics if
// something goes horribly wrong.
std::condition_variable cond_var_;
};
/**
@ -316,6 +336,8 @@ private:
* to lookup a kernel for a certain set of arguments.
*/
class TORCH_API OperatorHandle {
template <typename T> friend class std::hash;
public:
OperatorHandle(OperatorHandle&&) noexcept = default;
OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
@ -411,6 +433,14 @@ public:
return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
}
bool operator==(const OperatorHandle& other) const {
return operatorDef_ == other.operatorDef_;
}
bool operator!=(const OperatorHandle& other) const {
return operatorDef_ != other.operatorDef_;
}
private:
explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}
@ -695,3 +725,14 @@ inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet
}
} // namespace c10
namespace std {
template <>
struct hash<c10::OperatorHandle> {
size_t operator()(c10::OperatorHandle op) const noexcept {
return std::hash<void*>{}(static_cast<void*>(op.operatorDef_));
}
};
} // namespace std
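With operator== and this std::hash specialization, OperatorHandle now works as a key in unordered containers. A small sketch (the aten::add lookup is just illustrative):

std::unordered_set<c10::OperatorHandle> seen;
auto op = c10::Dispatcher::singleton().findSchema({"aten::add", "Tensor"});
if (op.has_value()) {
  seen.insert(*op);
  TORCH_INTERNAL_ASSERT(seen.count(*op) == 1);  // hashes by OperatorDef*
}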

View File

@ -89,7 +89,7 @@ Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, c
// merge everything
#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags) & {
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags, _RegisterOrVerify rv) & {
TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
DEF_PRELUDE,
"Cannot define an operator inside of a ", toString(kind_), " block. "
@ -125,13 +125,20 @@ Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name
if (out_name) {
*out_name = schema.operator_name(); // copy!
}
registrars_.emplace_back(
c10::Dispatcher::singleton().registerDef(
std::move(schema),
debugString(file_, line_),
tags
)
);
switch (rv) {
case _RegisterOrVerify::REGISTER:
registrars_.emplace_back(
c10::Dispatcher::singleton().registerDef(
std::move(schema),
debugString(file_, line_),
tags
)
);
break;
case _RegisterOrVerify::VERIFY:
c10::Dispatcher::singleton().waitForDef(schema);
break;
}
return *this;
}
#undef DEF_PRELUDE
@ -174,11 +181,10 @@ Library& Library::_def(c10::either<c10::OperatorName, c10::FunctionSchema>&& nam
}
#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
Library& Library::_impl(const char* name_str, CppFunction&& f) & {
at::OperatorName Library::_parseNameForLib(const char* name_str) const {
auto name = torch::jit::parseName(name_str);
auto ns_opt = name.getNamespace();
// This is kind of similar to the checking in def(), but the error
// messages are a little different for this call site
// This is a copy paste of Library::_impl
if (ns_opt.has_value()) {
// See Note [Redundancy in registration code is OK]
TORCH_CHECK(*ns_opt == *ns_,
@ -193,6 +199,11 @@ Library& Library::_impl(const char* name_str, CppFunction&& f) & {
bool b = name.setNamespaceIfNotSet(ns_->c_str());
TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
}
return name;
}
Library& Library::_impl(const char* name_str, CppFunction&& f, _RegisterOrVerify rv) & {
at::OperatorName name = _parseNameForLib(name_str);
// See Note [Redundancy in registration code is OK]
TORCH_CHECK(!(f.dispatch_key_.has_value() &&
dispatch_key_.has_value() &&
@ -205,19 +216,30 @@ Library& Library::_impl(const char* name_str, CppFunction&& f) & {
ERROR_CONTEXT
);
auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
registrars_.emplace_back(
c10::Dispatcher::singleton().registerImpl(
std::move(name),
dispatch_key,
std::move(f.func_),
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(f.cpp_signature_),
std::move(f.schema_),
debugString(std::move(f.debug_), file_, line_)
)
);
switch (rv) {
case _RegisterOrVerify::REGISTER:
registrars_.emplace_back(
c10::Dispatcher::singleton().registerImpl(
std::move(name),
dispatch_key,
std::move(f.func_),
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(f.cpp_signature_),
std::move(f.schema_),
debugString(std::move(f.debug_), file_, line_)
)
);
break;
case _RegisterOrVerify::VERIFY:
c10::Dispatcher::singleton().waitForImpl(name, dispatch_key);
break;
}
return *this;
}
c10::OperatorName Library::_resolve(const char* name_str) const {
return _parseNameForLib(name_str);
}
#undef IMPL_PRELUDE
Library& Library::_fallback(CppFunction&& f) & {

View File

@ -9,6 +9,7 @@
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/core/impl/SizesAndStrides.h>
@ -2037,7 +2038,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return c10::nullopt;
} else if (interpreter == self_interpreter) {
// NB: pyobj_ could still be null!
return c10::make_optional(_unchecked_untagged_pyobj());
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return c10::nullopt;
} else {
return c10::make_optional(_unchecked_untagged_pyobj());
}
} else {
TORCH_CHECK(
false,

View File

@ -0,0 +1,23 @@
#include <c10/core/impl/HermeticPyObjectTLS.h>
namespace c10 {
namespace impl {
thread_local std::atomic<bool> hermeticPyObjectState{false};
std::atomic<bool> HermeticPyObjectTLS::haveState_{false};
void HermeticPyObjectTLS::set_state(bool state) {
hermeticPyObjectState = state;
}
bool HermeticPyObjectTLS::get_tls_state() {
return hermeticPyObjectState;
}
void HermeticPyObjectTLS::init_state() {
haveState_ = true;
}
} // namespace impl
} // namespace c10

View File

@ -0,0 +1,61 @@
#pragma once
#include <c10/macros/Macros.h>
#include <atomic>
namespace c10 {
namespace impl {
// This TLS controls whether or not we permanently associate PyObject
// with Tensor the first time it is allocated. When hermetic PyObject
// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor,
// meaning you get a distinct PyObject whenever you execute the code in
// question.
struct C10_API HermeticPyObjectTLS {
static void set_state(bool state);
static bool get_state() {
// Hypothetical fastpath if torchdeploy/multipy isn't used. Per
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// this qualifies for relaxed access because it is a single-location data
// structure (only the boolean here).
//
// Forgetting about data races for a moment, is there a logical race?
//
// - Boolean only ever transitions from false to true. So the
// critical situation is when one interpreter is already running
// when a second interpreter switches haveState_ from false to true.
//
// - The first interpreter is indifferent whether or not it sees
// haveState_ as true or false; obviously false works (this is what the
// interpreter was previously using; more directly, the interpreter
// calls into itself as the handler, so being hermetic is not
// required), and true simply means serviced python operator calls will
// be hermetic; in these cases it is expected to be functionally
// equivalent.
//
// - The second interpreter MUST see haveState_ true (as its requests will
// be forwarded to the first interpreter), but it is assumed that there
// is a synchronization between the interpreter initialization, and
// when we actually perform operations, so it is guaranteed to see
// haveState_ true.
//
// QED.
//
// This fastpath is currently disabled so that we can more easily test that
// hermetic mode works correctly even on stock build of PyTorch.
if (false && !haveState_.load(std::memory_order_relaxed))
return false;
return get_tls_state();
}
// Call this from the multipy/torchdeploy top level
static void init_state();
private:
// This is only flipped once, from false to true, during torchdeploy/multipy
// initialization, and never again.
static std::atomic<bool> haveState_;
static bool get_tls_state();
};
} // namespace impl
} // namespace c10
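Usage-wise the TLS is a plain save/set/restore flag; the diff wraps it (together with some dispatch-key TLS juggling) in the EnableHermeticPyObject RAII guard in python_dispatch.cpp, but the core pattern is just:

// Sketch: any tensor surfaced to Python inside this window gets a fresh,
// uncached PyObject rather than being permanently bound to one interpreter.
bool old = c10::impl::HermeticPyObjectTLS::get_state();
c10::impl::HermeticPyObjectTLS::set_state(true);
// ... call into the Python kernel ...
c10::impl::HermeticPyObjectTLS::set_state(old);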

View File

@ -27,6 +27,13 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
PANIC(dispatch);
}
void python_op_registration_trampoline(
const c10::OperatorHandle& op,
c10::DispatchKey,
torch::jit::Stack* stack) const override {
PANIC(python_op_registration_trampoline);
}
void python_dispatcher(
const c10::OperatorHandle& op,
c10::DispatchKeySet,

View File

@ -141,6 +141,15 @@ struct C10_API PyInterpreterVTable {
virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
const = 0;
// This is only invoked in the multipy/torchdeploy situation from
// pythonOpRegistrationTrampoline; this lets us get to the Python
// interpreter to actually find the appropriate Python op registration
// entry to call.
virtual void python_op_registration_trampoline(
const c10::OperatorHandle& op,
c10::DispatchKey,
torch::jit::Stack* stack) const = 0;
// Invoke the Python dispatcher to handle this call
virtual void python_dispatcher(
const c10::OperatorHandle& op,

View File

@ -5673,36 +5673,6 @@ class TestTorch(TestCase):
r"the unspecified dimension size -1 can be any value and is ambiguous"):
torch.randn(2, 0).unflatten(1, (2, -1, 0))
def test_pytorch_library_disabled_env(self):
import subprocess
env = os.environ.copy()
env['PYTORCH_DISABLE_LIBRARY'] = '1'
try:
subprocess.check_output([sys.executable, '-c', 'import torch'], env=env)
except subprocess.CalledProcessError as e:
raise RuntimeError("Could not 'import torch' with PYTORCH_DISABLE_LIBRARY=0") from e
# Test that warnings generated from C++ are translated to the correct type
def test_warn_types(self):
test_cases = [
# function, warning type, message
(torch._C._warn, UserWarning, r"Test message for TORCH_WARN"),
(torch._C._warn_deprecation, DeprecationWarning, r"Test message for TORCH_WARN_DEPRECATION"),
]
for fn, warning_type, message in test_cases:
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
warnings.filterwarnings('always', category=warning_type)
fn()
self.assertEqual(len(w), 1, msg=f'{warning_type} not raised')
warning = w[0].message
self.assertTrue(isinstance(warning, warning_type), msg=f'{warning_type} not raised')
self.assertTrue(re.search(
message,
str(warning)))
def test_structseq_repr(self):
a = torch.arange(250).reshape(5, 5, 10)
expected = """

View File

@ -1037,7 +1037,6 @@ class _DispatchModule:
def def_name_t_t(self, name: str, dispatch: str, debug: str = "default_def_name_t_t") -> _DispatchModule: ...
def def_schema_t_t(self, schema: str, dispatch: str, alias: str, debug: str = "default_def_schema_t_t") -> _DispatchModule: ...
def impl_t_t(self, name: str, dispatch: str, debug: str = "impl_t_t") -> _DispatchModule: ...
def impl_tt_t(self, name: str, dispatch: str, debug: str = "impl_tt_t") -> _DispatchModule: ...
def impl(self, name: str, dispatch: str, func: Callable) -> _DispatchModule: ...
def define(self, schema: str, alias: str = "") -> _DispatchModule: ...
def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ...

View File

@ -1021,10 +1021,9 @@ def _register_device_module(device_type, module):
# expose return_types
from . import return_types
if sys.executable != 'torch_deploy' and os.environ.get('PYTORCH_DISABLE_LIBRARY', "0") == "0":
from . import library
if not TYPE_CHECKING:
from . import _meta_registrations
from . import library
if not TYPE_CHECKING:
from . import _meta_registrations
# Enable CUDA Sanitizer
if 'TORCH_CUDA_SANITIZER' in os.environ:

View File

@ -86,8 +86,7 @@ def make_dual(tensor, tangent, *, level=None):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
# Currently broken for 3.11, see https://github.com/pytorch/pytorch/issues/85506
if (os.environ.get("PYTORCH_JIT", "1" if sys.version_info < (3, 11) else "0") == "1" and
__debug__ and
os.environ.get('PYTORCH_DISABLE_LIBRARY', "0") == "0"):
__debug__):
from torch._decomp import decompositions_for_jvp # noqa: F401
if level is None:

View File

@ -1,8 +1,10 @@
#include <ATen/NamedTensorUtils.h>
#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <c10/core/DeviceType.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <c10/util/DeadlockDetection.h>
#include <c10/util/irange.h>
@ -31,6 +33,7 @@
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_memoryformats.h>
@ -219,6 +222,14 @@ struct ConcretePyInterpreterVTable final
const c10::OperatorHandle& op,
c10::DispatchKeySet,
torch::jit::Stack* stack) const override;
// NB: this is defined in python_dispatch.cpp
void python_op_registration_trampoline(
const c10::OperatorHandle& op,
c10::DispatchKey key,
torch::jit::Stack* stack) const override {
torch::impl::dispatch::python_op_registration_trampoline_impl(
op, key, stack);
}
bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override;
bool is_strides_like(const TensorImpl* self, at::MemoryFormat) const override;
@ -294,6 +305,10 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
// THPVariable_clear).
// 2. We are decref-ing some other Python object. We don't do
// PyObject resurrection on non-Tensors, so we just carry on as usual
if (is_tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!c10::impl::HermeticPyObjectTLS::get_state());
}
if (is_tensor && Py_REFCNT(pyobj) > 1) {
// It's still alive! This can happen if a weak ref resurrected
// the PyObject without flipping ownership. At this point it is
@ -314,7 +329,10 @@ class PyInterpreterHolder {
public:
PyInterpreterHolder()
: impl_(new c10::impl::PyInterpreter(
ConcretePyInterpreterVTable::instance())) {}
ConcretePyInterpreterVTable::instance())) {
is_main_interpreter_ =
at::impl::PythonOpRegistrationTrampoline::registerInterpreter(impl_);
}
// NB: intentionally leaks the PyInterpreter, as there may still be
// references to it that are live, living in objects that aren't being
// destructed while Python is being cleaned up.
@ -324,9 +342,13 @@ class PyInterpreterHolder {
c10::impl::PyInterpreter* get() const noexcept {
return impl_;
}
bool is_main_interpreter() const noexcept {
return is_main_interpreter_;
}
private:
c10::impl::PyInterpreter* impl_;
bool is_main_interpreter_;
};
PyInterpreterHolder self_interpreter;
@ -352,6 +374,10 @@ c10::impl::PyInterpreter* getPyInterpreter() {
return self_interpreter.get();
}
bool isMainPyInterpreter() {
return self_interpreter.is_main_interpreter();
}
std::string ConcretePyInterpreterVTable::name() const {
std::stringstream ss;
ss << getPyInterpreter();
@ -416,6 +442,13 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
Py_RETURN_NONE;
}
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return THPVariable_NewWithVar(
(PyTypeObject*)THPVariableClass,
std::move(var),
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
}
c10::optional<PyObject*> mb_obj =
var.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get());
c10::impl::PyInterpreterStatus status;
@ -489,6 +522,11 @@ bool isResurrectable(THPVariable* self) {
return false;
}
auto const& tensor = THPVariable_Unpack(self);
// Check if this is hermetic. If it is, no resurrection.
if (tensor.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()) !=
c10::make_optional((PyObject*)self)) {
return false;
}
if (!tensor.defined() || tensor.use_count() <= 1) {
return false;
}
@ -531,6 +569,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
// Flip THPVariable to be non-owning
// (near use-after-free miss here: fresh MaybeOwned is created breaking
// reference on Tensor in struct BEFORE we overwrite the old one)
TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
self->cdata = MaybeOwned<Variable>::borrowed(tensor);
// NB: At this point, tensor *could* be dead (e.g., some other C++ thread
@ -582,7 +621,9 @@ static int THPVariable_clear(THPVariable* self) {
// unsafeIsBorrowed() is TRUE. We're deallocating the PyObject
// because Tensor asked us to (it's already destructing).
if (!self->cdata.unsafeIsBorrowed()) {
if (!self->cdata.unsafeIsBorrowed() &&
tensor.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()) ==
c10::make_optional((PyObject*)self)) {
// TODO: empirically, on OS X this assert appears to be untrue
// In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
// distributed/rpc/test_process_group_agent.py
@ -1887,11 +1928,27 @@ static PyObject* THPVariable_NewWithVar(
auto v = (THPVariable*)obj;
// TODO: named constructor to avoid default initialization
new (&v->cdata) MaybeOwned<Variable>();
v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
const auto& var = THPVariable_Unpack(v);
var.unsafeGetTensorImpl()->init_pyobj(self_interpreter.get(), obj, status);
if (check_has_torch_dispatch(obj)) {
var.unsafeGetTensorImpl()->set_python_dispatch(true);
if (c10::impl::HermeticPyObjectTLS::get_state()) {
// Do NOT initialize pyobj field on the tensor, you own the C++
v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
TORCH_INTERNAL_ASSERT(
!check_has_torch_dispatch(obj),
"While HermeticPyObject was enabled, we attempted to create a tensor "
"subclass with __torch_dispatch__. This violates the invariant that "
"operations in HermeticPyObject have equivalent C++ implementations. "
"If your operator registered from Python operator registration isn't "
"doing anything strange, there may be an internal PyTorch bug involving "
"not appropriately disabling TorchDispatchMode before executing "
"Python op registration.");
} else {
// Normal codepath
v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
const auto& var = THPVariable_Unpack(v);
var.unsafeGetTensorImpl()->init_pyobj(
self_interpreter.get(), obj, status);
if (check_has_torch_dispatch(obj)) {
var.unsafeGetTensorImpl()->set_python_dispatch(true);
}
}
}
return obj;

View File

@ -69,6 +69,7 @@ inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
}
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
TORCH_PYTHON_API bool isMainPyInterpreter();
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
const c10::OperatorHandle& op,

View File

@ -4,6 +4,7 @@
#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
@ -11,6 +12,7 @@
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/utils/pybind.h>
@ -23,6 +25,14 @@ namespace torch {
namespace impl {
namespace dispatch {
// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations before
// the other interpreters start banging on it
static ska::flat_hash_map<
c10::OperatorName,
ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
python_registrations_;
torch::Library::Kind parseKind(const std::string& k) {
static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
{"DEF", torch::Library::DEF},
@ -58,19 +68,101 @@ inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
}
}
struct EnableHermeticPyObject {
EnableHermeticPyObject()
: old_(c10::impl::HermeticPyObjectTLS::get_state()),
old_excluded_python_(
c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
old_python_(
c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
at::DispatchKey::PythonTLSSnapshot)) {
c10::impl::HermeticPyObjectTLS::set_state(true);
c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
at::DispatchKey::PythonTLSSnapshot, false);
}
~EnableHermeticPyObject() {
c10::impl::HermeticPyObjectTLS::set_state(old_);
c10::impl::tls_set_dispatch_key_excluded(
at::DispatchKey::Python, old_excluded_python_);
c10::impl::tls_set_dispatch_key_included(
at::DispatchKey::Python, old_python_);
c10::impl::tls_set_dispatch_key_included(
at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
}
bool old_;
bool old_excluded_python_;
bool old_python_;
bool old_python_snapshot_;
};
class PythonKernelHolder : public c10::OperatorKernel {
c10::SafePyObject func_;
c10::DispatchKey dispatch_key_;
public:
PythonKernelHolder(py::object func)
: func_(func.release().ptr(), getPyInterpreter()) {}
PythonKernelHolder(py::object func, c10::DispatchKey dispatch_key)
: func_(func.release().ptr(), getPyInterpreter()),
dispatch_key_(dispatch_key) {}
void operator()(
const c10::OperatorHandle& op,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack) {
// Figure out if we can handle it hermetically, or if we have
// to double dispatch
// If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
if (mode_stack_len > 0) {
const auto& cur_torch_dispatch_mode_state =
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
cur_torch_dispatch_mode_state->pyinterpreter()
->python_op_registration_trampoline(op, dispatch_key_, stack);
return;
}
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
// Otherwise, find a PyInterpreter on a Tensor if it has the Python key (which
// means it's a nontrivial tensor subclass)
for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
if (ivalue.isTensor()) {
auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
if (interpreter &&
ivalue.unsafeToTensorImpl()->key_set().has(
at::DispatchKey::Python)) {
(*interpreter)
->python_op_registration_trampoline(op, dispatch_key_, stack);
return;
}
} else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
// NB: use toListRef as it doesn't induce refcount bumps
// (toTensorListRef is not a thing)
for (const auto& nv : ivalue.toListRef()) {
if (nv.isNone()) {
continue;
}
auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
if (interpreter &&
nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
(*interpreter)
->python_op_registration_trampoline(op, dispatch_key_, stack);
return;
}
}
}
}
// Nothing requires the operator to be homed to a specific interpreter, so
// run it on the current interpreter
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
py::gil_scoped_acquire g;
EnableHermeticPyObject g2;
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto obj = py::reinterpret_steal<py::object>(PyObject_Call(
func_.ptr(getPyInterpreter()),
@ -83,6 +175,14 @@ class PythonKernelHolder : public c10::OperatorKernel {
}
};
torch::_RegisterOrVerify register_or_verify() {
if (isMainPyInterpreter()) {
return torch::_RegisterOrVerify::REGISTER;
} else {
return torch::_RegisterOrVerify::VERIFY;
}
}
void initDispatchBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
@ -91,9 +191,12 @@ void initDispatchBindings(PyObject* module) {
// TODO: figure out how to do chaining
py::class_<torch::Library>(m, "_DispatchModule")
// Some of these APIs are only for testing and do not work in multipy
// environment
.def(
"def_",
[](py::object self, const char* schema, const char* alias) {
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().def(
torch::schema(schema, parseAliasAnalysisKind(alias)));
return self;
@ -107,6 +210,7 @@ void initDispatchBindings(PyObject* module) {
.def(
"def_legacy",
[](py::object self, const char* schema) {
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
return self;
},
@ -126,6 +230,7 @@ void initDispatchBindings(PyObject* module) {
const char* name,
const char* dispatch,
const char* debug) {
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().def(
name, dispatch_str(dispatch, [](const at::Tensor& a) {
return a;
@ -143,6 +248,7 @@ void initDispatchBindings(PyObject* module) {
const char* dispatch,
const char* alias,
const char* debug) {
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().def(
torch::schema(schema, parseAliasAnalysisKind(alias)),
dispatch_str(dispatch, [](const at::Tensor& a) {
@ -163,6 +269,7 @@ void initDispatchBindings(PyObject* module) {
const char* name,
const char* dispatch,
const char* debug) {
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().impl(
name, dispatch_str(dispatch, [](const at::Tensor& a) {
return a;
@ -173,38 +280,26 @@ void initDispatchBindings(PyObject* module) {
py::arg("name"),
py::arg("dispatch") = "",
py::arg("debug") = "impl_t_t")
.def(
"impl_tt_t",
[](py::object self,
const char* name,
const char* dispatch,
const char* debug) {
self.cast<torch::Library&>().impl(
name,
dispatch_str(
dispatch,
[](const at::Tensor& a, const at::Tensor& b) { return a; })
.debug(debug));
return self;
},
"",
py::arg("name"),
py::arg("dispatch") = "",
py::arg("debug") = "")
.def(
"impl",
[](py::object self,
const char* name,
const char* dispatch,
// TODO: empty string no longer works
c10::DispatchKey dispatch,
py::object func) {
HANDLE_TH_ERRORS
self.cast<torch::Library&>().impl(
auto& lib = self.cast<torch::Library&>();
lib.impl(
name,
dispatch_str(
torch::dispatch(
dispatch,
CppFunction::makeFromBoxedFunctor(
std::make_unique<PythonKernelHolder>(
std::move(func)))));
std::make_unique<PythonKernelHolder>(func, dispatch))),
register_or_verify());
python_registrations_[lib._resolve(name)].insert_or_assign(
dispatch,
std::make_shared<c10::SafePyObject>(
func.release().ptr(), getPyInterpreter()));
END_HANDLE_TH_ERRORS_PYBIND
},
"",
@ -214,8 +309,11 @@ void initDispatchBindings(PyObject* module) {
.def(
"define",
[](py::object self, const char* schema, const char* alias_analysis) {
auto parsed_schema =
torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
self.cast<torch::Library&>().def(
torch::schema(schema, parseAliasAnalysisKind(alias_analysis)));
std::move(parsed_schema), {}, register_or_verify());
// TODO: this is dumb, had to make a second copy
return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
.name();
},
@ -225,6 +323,7 @@ void initDispatchBindings(PyObject* module) {
.def(
"fallback_fallthrough",
[](py::object self, const char* dispatch) {
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().fallback(
dispatch_str(dispatch, CppFunction::makeFallthrough()));
return self;
@ -488,6 +587,9 @@ void initDispatchBindings(PyObject* module) {
},
py::arg("dispatch_key") = static_cast<const char*>(""));
m.def(
"_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
m.def("_are_functorch_transforms_active", []() {
auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
return (
@ -496,6 +598,26 @@ void initDispatchBindings(PyObject* module) {
});
}
// TODO: dedupe with the kernel
void python_op_registration_trampoline_impl(
const c10::OperatorHandle& op,
c10::DispatchKey key,
torch::jit::Stack* stack) {
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
py::gil_scoped_acquire g;
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
const auto& func = python_registrations_[op.operator_name()][key];
TORCH_INTERNAL_ASSERT(func != nullptr);
auto* pyobj = func->ptr(getPyInterpreter());
TORCH_INTERNAL_ASSERT(pyobj != nullptr);
auto obj = py::reinterpret_steal<py::object>(
PyObject_Call(pyobj, args_kwargs.first.ptr(), args_kwargs.second.ptr()));
if (!obj) {
throw python_error();
}
pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}
} // namespace dispatch
} // namespace impl
} // namespace torch

View File

@ -7,6 +7,11 @@ namespace dispatch {
void initDispatchBindings(PyObject* module);
}
void python_op_registration_trampoline_impl(
const c10::OperatorHandle& op,
c10::DispatchKey key,
torch::jit::Stack* stack);
} // namespace dispatch
} // namespace impl
} // namespace torch

View File

@ -86,6 +86,12 @@ namespace torch {
struct NoInferSchemaTag {};
#endif
// For multipy/torchdeploy use case
enum class _RegisterOrVerify {
REGISTER,
VERIFY
};
template <class CurClass>
class class_;
@ -591,9 +597,9 @@ class TORCH_API Library final {
/// ```
template <typename Schema>
Library& def(Schema&& raw_schema, const std::vector<at::Tag>& tags = {}) & {
Library& def(Schema&& raw_schema, const std::vector<at::Tag>& tags = {}, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
return _def(std::move(s), nullptr, tags);
return _def(std::move(s), nullptr, tags, rv);
}
/// Define an operator for a schema and then register an implementation for
/// it. This is typically what you would use if you aren't planning
@ -644,7 +650,7 @@ class TORCH_API Library final {
/// }
/// ```
template <typename Name, typename Func>
Library& impl(Name name, Func&& raw_f) & {
Library& impl(Name name, Func&& raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
// TODO: need to raise an error when you impl a function that has a
// catch all def
#if defined C10_MOBILE
@ -652,7 +658,7 @@ class TORCH_API Library final {
#else
CppFunction f(std::forward<Func>(raw_f));
#endif
return _impl(name, std::move(f));
return _impl(name, std::move(f), rv);
}
#if defined C10_MOBILE
@ -673,6 +679,10 @@ class TORCH_API Library final {
}
#endif
// Helper for getting an OperatorName for a const char*. You probably
// don't need this.
c10::OperatorName _resolve(const char* name) const;
/// \private
///
/// Convenience overload for directly specifying the dispatch key when
@ -809,12 +819,17 @@ class TORCH_API Library final {
Library& _def(
c10::FunctionSchema&& schema,
c10::OperatorName* out_name = nullptr,
const std::vector<at::Tag>& tags = {}) &;
const std::vector<at::Tag>& tags = {},
_RegisterOrVerify rv = _RegisterOrVerify::REGISTER
) &;
Library& _def(
c10::either<c10::OperatorName, c10::FunctionSchema>&&,
CppFunction&& f) &;
Library& _impl(const char* name, CppFunction&& f) &;
Library& _impl(const char* name, CppFunction&& f,
_RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
Library& _fallback(CppFunction&& f) &;
at::OperatorName _parseNameForLib(const char* name_str) const;
};
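The new rv parameter threads through def/impl so a non-main interpreter can replay its registrations in VERIFY mode; a sketch (assuming lib is a torch::Library and fn a kernel functor):

// REGISTER (the default) writes to the dispatcher; VERIFY merely waits for
// the main interpreter's matching registration to show up.
lib.def("myns::foo(Tensor x) -> Tensor", /*tags=*/{},
        torch::_RegisterOrVerify::VERIFY);
lib.impl("foo", std::move(fn), torch::_RegisterOrVerify::VERIFY);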
namespace detail {

View File

@ -2,7 +2,6 @@ from ._ops import OpOverload
from typing import Set
import traceback
import torch
import os
__all__ = ['Library', 'impl', 'define']
@ -30,9 +29,6 @@ class Library:
dispatch_key: PyTorch dispatch key (default: "")
"""
def __init__(self, ns, kind, dispatch_key=""):
if os.environ.get('PYTORCH_DISABLE_LIBRARY', "0") == "1":
raise RuntimeError("Trying to use torch.library in an environment where it is disabled")
if kind != "IMPL" and kind != "DEF":
raise ValueError("Unsupported kind: ", kind)
@ -126,7 +122,8 @@ class Library:
" Instead we should let the operator decompose, and ensure that we have meta kernels"
" for the base ops that it decomposes into.")
self.m.impl(name, dispatch_key, fn)
self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn)
_impls.add(key)
self._op_impls.add(key)