Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Make Python op registration work with torchdeploy/multipy (#87162)

See the strategy at PythonOpRegistrationTrampoline.cpp for the big picture.

Along the way, I made OperatorHandle support == and hashing, and slightly changed the low-level python_dispatch impl API to disallow empty strings for the dispatch key, which had the knock-on effect of requiring us to explicitly pass in CompositeImplicitAutograd where we would previously have passed "" (I didn't apply this to the rest of the file because I'm lazy.)

Test strategy: we delete the logic that skipped Python op registrations in torch in a torchdeploy context and show CI still works.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87162
Approved by: https://github.com/anjali411, https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 2f296cfdbb
Commit: f884e817d4
aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp (new file, 28 lines)
@@ -0,0 +1,28 @@
+#include <ATen/core/PythonOpRegistrationTrampoline.h>
+
+namespace at {
+namespace impl {
+
+// The strategy is that all python interpreters attempt to register themselves
+// as the main interpreter, but only one wins.  Only that interpreter is
+// allowed to interact with the C++ dispatcher.  Furthermore, when we execute
+// logic on that interpreter, we do so hermetically, never setting pyobj field
+// on Tensor.
+
+std::atomic<c10::impl::PyInterpreter*> PythonOpRegistrationTrampoline::interpreter_{nullptr};
+
+bool PythonOpRegistrationTrampoline::registerInterpreter(c10::impl::PyInterpreter* interp) {
+  c10::impl::PyInterpreter* expected = nullptr;
+  interpreter_.compare_exchange_strong(expected, interp);
+  if (expected != nullptr) {
+    // This is the second (or later) Python interpreter, which means we need
+    // non-trivial hermetic PyObject TLS
+    c10::impl::HermeticPyObjectTLS::init_state();
+    return false;
+  } else {
+    return true;
+  }
+}
+
+} // namespace impl
+} // namespace at
aten/src/ATen/core/PythonOpRegistrationTrampoline.h (new file, 18 lines)
@@ -0,0 +1,18 @@
+#include <ATen/core/dispatch/Dispatcher.h>
+
+// TODO: this can probably live in c10
+
+namespace at {
+namespace impl {
+
+class TORCH_API PythonOpRegistrationTrampoline final {
+  static std::atomic<c10::impl::PyInterpreter*> interpreter_;
+
+ public:
+  // Returns true if you successfully registered yourself (that means
+  // you are in the hot seat for doing the operator registrations!)
+  static bool registerInterpreter(c10::impl::PyInterpreter*);
+};
+
+} // namespace impl
+} // namespace at
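The first-registrant-wins race above hinges on a detail of std::atomic::compare_exchange_strong: on success `expected` keeps its original nullptr, and on failure it is overwritten with the winner's pointer. A minimal standalone sketch of the same pattern (FakeInterpreter is a hypothetical stand-in, not a PyTorch type):

#include <atomic>
#include <cstdio>
#include <thread>

struct FakeInterpreter { int id; };  // hypothetical stand-in for PyInterpreter

std::atomic<FakeInterpreter*> main_interpreter{nullptr};

// Mirrors registerInterpreter: true iff this caller won the race.
bool register_interpreter(FakeInterpreter* interp) {
  FakeInterpreter* expected = nullptr;
  // On success main_interpreter becomes interp and expected stays nullptr;
  // on failure expected is overwritten with the winner's pointer.
  main_interpreter.compare_exchange_strong(expected, interp);
  return expected == nullptr;
}

int main() {
  FakeInterpreter a{1}, b{2};
  bool a_won = false, b_won = false;
  std::thread ta([&] { a_won = register_interpreter(&a); });
  std::thread tb([&] { b_won = register_interpreter(&b); });
  ta.join();
  tb.join();
  std::printf("a won: %d, b won: %d\n", a_won, b_won);  // exactly one wins
  return a_won == b_won ? 1 : 0;
}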
aten/src/ATen/core/dispatch/Dispatcher.cpp
@@ -1,6 +1,7 @@
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <list>
 #include <sstream>
+#include <chrono>

 namespace c10 {

@@ -50,7 +51,9 @@ Dispatcher::Dispatcher()
 , operatorLookupTable_()
 , backendFallbackKernels_()
 , listeners_(std::make_unique<detail::RegistrationListenerList>())
-, mutex_() {}
+, mutex_()
+, cond_var_()
+{}

 Dispatcher::~Dispatcher() = default;

@@ -69,6 +72,41 @@ c10::optional<OperatorHandle> Dispatcher::findOp(const OperatorName& overload_na
   });
 }

+// NB: If you add more waitFor* implementations, you also have to add
+// appropriate notify_all() calls to the relevant register calls
+
+void Dispatcher::waitForDef(const FunctionSchema& schema) {
+  using namespace std::chrono_literals;
+  std::unique_lock<std::mutex> lock(mutex_);
+  bool r = cond_var_.wait_for(lock, 2s, [&]{
+    return findOp(schema.operator_name()) != c10::nullopt;
+  });
+  TORCH_INTERNAL_ASSERT(r,
+    "Expected main interpreter to define ", schema.operator_name(),
+    ", but this didn't happen within timeout.  Are you trying to load "
+    "different models in the same torchdeploy/multipy instance?  You "
+    "must warmup each interpreter identically, e.g., import all "
+    "the same dependencies.");
+}
+
+void Dispatcher::waitForImpl(const OperatorName& op_name, c10::optional<c10::DispatchKey> maybe_dk) {
+  using namespace std::chrono_literals;
+  std::unique_lock<std::mutex> lock(mutex_);
+  auto dk = maybe_dk.value_or(DispatchKey::CompositeImplicitAutograd);
+  auto op = findOrRegisterName_(op_name);
+  bool r = cond_var_.wait_for(lock, 2s, [&]{
+    // NB: this is slightly unsound for overrides, but overrides are
+    // funny business anyway
+    return op.hasKernelForDispatchKey(dk);
+  });
+  TORCH_INTERNAL_ASSERT(r,
+    "Expected main interpreter to implement ", dk, " for ", op_name,
+    ", but this didn't happen within timeout.  Are you trying to load "
+    "different models in the same torchdeploy/multipy instance?  You "
+    "must warmup each interpreter identically, e.g., import all "
+    "the same dependencies.");
+}
+
 c10::optional<OperatorHandle> Dispatcher::findSchema(const OperatorName& overload_name) {
   auto it = findOp(overload_name);
   if (it.has_value()) {
@@ -175,6 +213,8 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin
   ++op.operatorDef_->def_count;
   ++op.operatorDef_->def_and_impl_count;

+  cond_var_.notify_all();
+
   return RegistrationHandleRAII([this, op, op_name] {
     deregisterDef_(op, op_name);
   });
@@ -227,6 +267,8 @@ RegistrationHandleRAII Dispatcher::registerImpl(

   ++op.operatorDef_->def_and_impl_count;

+  cond_var_.notify_all();
+
   return RegistrationHandleRAII([this, op, op_name, dispatch_key, handle] {
     deregisterImpl_(op, op_name, dispatch_key, handle);
   });
@@ -249,6 +291,7 @@ RegistrationHandleRAII Dispatcher::registerName(OperatorName op_name) {
   std::lock_guard<std::mutex> lock(mutex_);
   auto op = findOrRegisterName_(op_name);
   ++op.operatorDef_->def_and_impl_count;
+
   return RegistrationHandleRAII(
       [this, op, op_name] { deregisterName_(op, op_name); });
 }
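waitForDef and waitForImpl above are instances of the standard condition-variable idiom: wait under a lock, with a predicate, and with a timeout so a missing registration becomes a diagnostic instead of a hang. A self-contained sketch of that idiom (the names here are illustrative, not PyTorch's):

#include <cassert>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <thread>

std::mutex mu;
std::condition_variable cv;
bool defined = false;  // stands in for "the op is now in the dispatch table"

// Mirrors waitForDef: block until the predicate holds or 2s elapse.
void wait_for_def() {
  using namespace std::chrono_literals;
  std::unique_lock<std::mutex> lock(mu);
  bool ok = cv.wait_for(lock, 2s, [] { return defined; });
  assert(ok && "main interpreter never defined the op (warmup mismatch?)");
}

// Mirrors registerDef: publish under the lock, then notify_all.
void register_def() {
  {
    std::lock_guard<std::mutex> lock(mu);
    defined = true;
  }
  cv.notify_all();
}

int main() {
  std::thread waiter(wait_for_def);
  std::thread registrar(register_def);
  waiter.join();
  registrar.join();
}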
aten/src/ATen/core/dispatch/Dispatcher.h
@@ -11,6 +11,7 @@
 #include <c10/util/LeftRight.h>
 #include <list>
 #include <mutex>
+#include <condition_variable>
 #include <type_traits>

 #include <ATen/core/grad_mode.h>
@@ -182,6 +183,9 @@ public:
     return backendFallbackKernels_[dispatch_ix].kernel.isValid();
   }

+  // Used by torchdeploy/multipy for multiple interpreters racing.
+  void waitForDef(const FunctionSchema& schema);
+  void waitForImpl(const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key);
+
   // ------------------------------------------------------------------------
   //
@@ -307,7 +311,23 @@ private:
   std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;

   std::unique_ptr<detail::RegistrationListenerList> listeners_;
+
+  // This mutex protects concurrent access to the dispatcher
   std::mutex mutex_;
+
+  // This condition variable gets notified whenever we add a new def/impl to the
+  // dispatch table.  This is primarily used by multipy/torchdeploy, when
+  // we have multiple interpreters trying to register to the dispatch table.
+  // In this situation, whenever the non-primary interpreter would have tried
+  // to register to the dispatch table, instead it will check to see if the
+  // expected registration has already been made, and if it hasn't, wait on
+  // this condition variable to see if it was just racing with the primary
+  // interpreter.
+  //
+  // We expect it to be rare for there to be any waiters on this condition
+  // variable.  This is mostly just to help give better diagnostics if
+  // something goes horribly wrong
+  std::condition_variable cond_var_;
 };

 /**
@@ -316,6 +336,8 @@ private:
  * to lookup a kernel for a certain set of arguments.
  */
 class TORCH_API OperatorHandle {
+  template <typename T> friend class std::hash;
+
 public:
   OperatorHandle(OperatorHandle&&) noexcept = default;
   OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
@@ -411,6 +433,14 @@ public:
     return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
   }

+  bool operator==(const OperatorHandle& other) const {
+    return operatorDef_ == other.operatorDef_;
+  }
+
+  bool operator!=(const OperatorHandle& other) const {
+    return operatorDef_ != other.operatorDef_;
+  }
+
 private:
   explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
   : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}
@@ -695,3 +725,14 @@ inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet
 }

 } // namespace c10
+
+namespace std {
+
+template <>
+struct hash<c10::OperatorHandle> {
+  size_t operator()(c10::OperatorHandle op) const noexcept {
+    return std::hash<void*>{}(static_cast<void*>(op.operatorDef_));
+  }
+};
+
+} // namespace std
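The == / != pair and the std::hash specialization above make OperatorHandle equality mean "same underlying OperatorDef", i.e. pointer identity, which is what lets handles key unordered containers. A standalone sketch of that pattern (Handle and Def are illustrative stand-ins):

#include <cstddef>
#include <functional>
#include <unordered_set>

struct Def {};  // stand-in for Dispatcher::OperatorDef

struct Handle {
  Def* def;
  // Two handles are equal iff they point at the same Def, exactly as
  // OperatorHandle::operator== compares operatorDef_ pointers.
  bool operator==(const Handle& other) const { return def == other.def; }
};

namespace std {
template <>
struct hash<Handle> {
  size_t operator()(Handle h) const noexcept {
    // Hash the pointer value itself, as the OperatorHandle hash does.
    return std::hash<void*>{}(static_cast<void*>(h.def));
  }
};
} // namespace std

int main() {
  Def d1, d2;
  std::unordered_set<Handle> seen;
  seen.insert(Handle{&d1});
  seen.insert(Handle{&d1});  // duplicate identity: set size stays 1
  seen.insert(Handle{&d2});
  return seen.size() == 2 ? 0 : 1;
}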
aten/src/ATen/core/library.cpp
@@ -89,7 +89,7 @@ Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, c
 // merge everything

 #define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
-Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags) & {
+Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags, _RegisterOrVerify rv) & {
   TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
     DEF_PRELUDE,
     "Cannot define an operator inside of a ", toString(kind_), " block.  "
@@ -125,13 +125,20 @@ Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name
   if (out_name) {
     *out_name = schema.operator_name(); // copy!
   }
-  registrars_.emplace_back(
-    c10::Dispatcher::singleton().registerDef(
-      std::move(schema),
-      debugString(file_, line_),
-      tags
-    )
-  );
+  switch (rv) {
+    case _RegisterOrVerify::REGISTER:
+      registrars_.emplace_back(
+        c10::Dispatcher::singleton().registerDef(
+          std::move(schema),
+          debugString(file_, line_),
+          tags
+        )
+      );
+      break;
+    case _RegisterOrVerify::VERIFY:
+      c10::Dispatcher::singleton().waitForDef(schema);
+      break;
+  }
   return *this;
 }
 #undef DEF_PRELUDE
@@ -174,11 +181,10 @@ Library& Library::_def(c10::either<c10::OperatorName, c10::FunctionSchema>&& nam
 }

 #define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
-Library& Library::_impl(const char* name_str, CppFunction&& f) & {
+at::OperatorName Library::_parseNameForLib(const char* name_str) const {
   auto name = torch::jit::parseName(name_str);
   auto ns_opt = name.getNamespace();
-  // This is kind of similar to the checking in def(), but the error
-  // messages are a little different for this call site
+  // This is a copy paste of Library::_impl
   if (ns_opt.has_value()) {
     // See Note [Redundancy in registration code is OK]
     TORCH_CHECK(*ns_opt == *ns_,
@@ -193,6 +199,11 @@ Library& Library::_impl(const char* name_str, CppFunction&& f) & {
     bool b = name.setNamespaceIfNotSet(ns_->c_str());
     TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
   }
+  return name;
+}
+
+Library& Library::_impl(const char* name_str, CppFunction&& f, _RegisterOrVerify rv) & {
+  at::OperatorName name = _parseNameForLib(name_str);
   // See Note [Redundancy in registration code is OK]
   TORCH_CHECK(!(f.dispatch_key_.has_value() &&
     dispatch_key_.has_value() &&
@@ -205,19 +216,30 @@ Library& Library::_impl(const char* name_str, CppFunction&& f) & {
     ERROR_CONTEXT
   );
   auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
-  registrars_.emplace_back(
-    c10::Dispatcher::singleton().registerImpl(
-      std::move(name),
-      dispatch_key,
-      std::move(f.func_),
-      // NOLINTNEXTLINE(performance-move-const-arg)
-      std::move(f.cpp_signature_),
-      std::move(f.schema_),
-      debugString(std::move(f.debug_), file_, line_)
-    )
-  );
+  switch (rv) {
+    case _RegisterOrVerify::REGISTER:
+      registrars_.emplace_back(
+        c10::Dispatcher::singleton().registerImpl(
+          std::move(name),
+          dispatch_key,
+          std::move(f.func_),
+          // NOLINTNEXTLINE(performance-move-const-arg)
+          std::move(f.cpp_signature_),
+          std::move(f.schema_),
+          debugString(std::move(f.debug_), file_, line_)
+        )
+      );
+      break;
+    case _RegisterOrVerify::VERIFY:
+      c10::Dispatcher::singleton().waitForImpl(name, dispatch_key);
+      break;
+  }
   return *this;
 }
+
+c10::OperatorName Library::_resolve(const char* name_str) const {
+  return _parseNameForLib(name_str);
+}
+
 #undef IMPL_PRELUDE

 Library& Library::_fallback(CppFunction&& f) & {
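The switch on _RegisterOrVerify above gives every def/impl call two modes: the main interpreter registers, and every other interpreter verifies, blocking briefly in waitForDef/waitForImpl until the registration appears. A hypothetical call site built on the API added in this commit (the helper name, schema string, and is_main flag are illustrative; the real mode selection is register_or_verify() in torch/csrc/utils/python_dispatch.cpp):

#include <torch/library.h>

// Hypothetical helper: register on the main interpreter, verify elsewhere.
void def_or_wait(torch::Library& lib, bool is_main) {
  auto rv = is_main ? torch::_RegisterOrVerify::REGISTER
                    : torch::_RegisterOrVerify::VERIFY;
  // In VERIFY mode this blocks (up to the 2s timeout) until the main
  // interpreter's registration shows up in the dispatcher.
  lib.def("myop(Tensor self) -> Tensor", {}, rv);
}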
c10/core/TensorImpl.h
@@ -9,6 +9,7 @@
 #include <c10/core/SymIntArrayRef.h>
 #include <c10/core/TensorOptions.h>
 #include <c10/core/WrapDimMinimal.h>
+#include <c10/core/impl/HermeticPyObjectTLS.h>
 #include <c10/core/impl/LocalDispatchKeySet.h>
 #include <c10/core/impl/PyInterpreter.h>
 #include <c10/core/impl/SizesAndStrides.h>
@@ -2037,7 +2038,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
       return c10::nullopt;
     } else if (interpreter == self_interpreter) {
       // NB: pyobj_ could still be null!
-      return c10::make_optional(_unchecked_untagged_pyobj());
+      if (c10::impl::HermeticPyObjectTLS::get_state()) {
+        return c10::nullopt;
+      } else {
+        return c10::make_optional(_unchecked_untagged_pyobj());
+      }
     } else {
       TORCH_CHECK(
           false,
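The check_pyobj change above means that under hermetic mode a tensor reports "no PyObject" even when one is cached, which is what forces THPVariable_Wrap (later in this diff) to mint a fresh, interpreter-local wrapper. A compact standalone model of that lookup, simplified to one interpreter and plain pointers:

#include <cassert>
#include <optional>

bool hermetic = false;         // stand-in for HermeticPyObjectTLS::get_state()
void* cached_pyobj = nullptr;  // stand-in for TensorImpl's pyobj field

// Mirrors the patched branch: hermetic callers never see the cached
// wrapper, so they always build a fresh one instead of reusing it.
std::optional<void*> check_pyobj() {
  if (hermetic) {
    return std::nullopt;
  }
  return cached_pyobj;  // NB: may be a null pointer wrapped in an optional
}

int main() {
  int dummy;
  cached_pyobj = &dummy;
  assert(check_pyobj().has_value());
  hermetic = true;
  assert(!check_pyobj().has_value());
}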
c10/core/impl/HermeticPyObjectTLS.cpp (new file, 23 lines)
@@ -0,0 +1,23 @@
+#include <c10/core/impl/HermeticPyObjectTLS.h>
+
+namespace c10 {
+namespace impl {
+
+thread_local std::atomic<bool> hermeticPyObjectState{false};
+
+std::atomic<bool> HermeticPyObjectTLS::haveState_{false};
+
+void HermeticPyObjectTLS::set_state(bool state) {
+  hermeticPyObjectState = state;
+}
+
+bool HermeticPyObjectTLS::get_tls_state() {
+  return hermeticPyObjectState;
+}
+
+void HermeticPyObjectTLS::init_state() {
+  haveState_ = true;
+}
+
+} // namespace impl
+} // namespace c10
c10/core/impl/HermeticPyObjectTLS.h (new file, 61 lines)
@@ -0,0 +1,61 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <atomic>
+
+namespace c10 {
+namespace impl {
+
+// This TLS controls whether or not we permanently associate PyObject
+// with Tensor the first time it is allocated.  When hermetic PyObject
+// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor,
+// meaning you get a distinct PyObject whenever you execute the code in
+// question.
+struct C10_API HermeticPyObjectTLS {
+  static void set_state(bool state);
+  static bool get_state() {
+    // Hypothetical fastpath if torchdeploy/multipy isn't used.  Per
+    // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
+    // this qualifies for relaxed access because it is a single-location data
+    // structure (only the boolean here).
+    //
+    // Forgetting about data races for a moment, is there a logical race?
+    //
+    //  - Boolean only ever transitions from false to true.  So the
+    //    critical situation is when one interpreter is already running
+    //    when a second interpreter switches haveState from false to true.
+    //
+    //  - The first interpreter is indifferent whether or not it sees
+    //    hasState true/false; obviously false works (this is what the
+    //    interpreter was previously using; more directly, the interpreter
+    //    calls into itself as the handler, so being hermetic is not
+    //    required), and true simply means serviced python operator calls will
+    //    be hermetic; in these cases it is expected to be functionally
+    //    equivalent.
+    //
+    //  - The second interpreter MUST see hasState true (as its requests will
+    //    be forwarded to the first interpreter), but it is assumed that there
+    //    is a synchronization between the interpreter initialization, and
+    //    when we actually perform operations, so it is guaranteed to see
+    //    hasState true.
+    //
+    // QED.
+    //
+    // This fastpath is currently disabled so that we can more easily test that
+    // hermetic mode works correctly even on stock build of PyTorch.
+    if (false && !haveState_.load(std::memory_order_relaxed))
+      return false;
+    return get_tls_state();
+  }
+  // Call this from the multipy/torchdeploy top level
+  static void init_state();
+
+ private:
+  // This is only flipped once, from false to true, during torchdeploy/multipy
+  // initialization, and never again.
+  static std::atomic<bool> haveState_;
+  static bool get_tls_state();
+};
+
+} // namespace impl
+} // namespace c10
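The header's long comment justifies a relaxed load of haveState_ by observing that the flag is a single memory location that only ever transitions false to true. A standalone model of the same fast-path structure (names abbreviated; note the real fast path is deliberately disabled in this commit so hermetic mode gets exercised on stock builds):

#include <atomic>
#include <cassert>

thread_local bool tls_hermetic = false;  // per-thread hermetic bit
std::atomic<bool> have_state{false};     // flips false -> true once, ever

void init_state() {  // called when a second interpreter registers
  have_state.store(true, std::memory_order_relaxed);
}

bool get_state() {
  // Single-location flag that only transitions false -> true: a relaxed
  // load that observes a stale `false` is always safe to act on.
  if (!have_state.load(std::memory_order_relaxed))
    return false;  // fast path: no multipy/torchdeploy in use
  return tls_hermetic;
}

int main() {
  assert(!get_state());
  init_state();
  tls_hermetic = true;
  assert(get_state());
}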
c10/core/impl/PyInterpreter.cpp
@@ -27,6 +27,13 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
     PANIC(dispatch);
   }

+  void python_op_registration_trampoline(
+      const c10::OperatorHandle& op,
+      c10::DispatchKey,
+      torch::jit::Stack* stack) const override {
+    PANIC(python_op_registration_trampoline);
+  }
+
   void python_dispatcher(
       const c10::OperatorHandle& op,
       c10::DispatchKeySet,
c10/core/impl/PyInterpreter.h
@@ -141,6 +141,15 @@ struct C10_API PyInterpreterVTable {
   virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
       const = 0;

+  // This is only invoked in the multipy/torchdeploy situation from
+  // pythonOpRegistrationTrampoline; this lets us get to the Python
+  // interpreter to actually find the appropriate Python op registration
+  // entry to call.
+  virtual void python_op_registration_trampoline(
+      const c10::OperatorHandle& op,
+      c10::DispatchKey,
+      torch::jit::Stack* stack) const = 0;
+
   // Invoke the Python dispatcher to handle this call
   virtual void python_dispatcher(
       const c10::OperatorHandle& op,
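python_op_registration_trampoline is a pure-virtual slot on PyInterpreterVTable, so a call can be routed to whichever interpreter owns the registration, while the no-op vtable fails loudly once Python is gone. A minimal standalone model of that vtable shape (types simplified to ints, not the real signatures):

#include <cstdio>
#include <cstdlib>

// Simplified stand-in for PyInterpreterVTable: one slot per Python entry point.
struct VTable {
  virtual ~VTable() = default;
  virtual void trampoline(int op, int key) const = 0;
};

// Like NoopPyInterpreterVTable: installed after Python shuts down, so any
// stray call aborts with a clear message instead of touching a dead runtime.
struct NoopVTable final : VTable {
  void trampoline(int, int) const override {
    std::fprintf(stderr, "PANIC: trampoline called with no live interpreter\n");
    std::abort();
  }
};

// Like ConcretePyInterpreterVTable: forwards into the live interpreter.
struct ConcreteVTable final : VTable {
  void trampoline(int op, int key) const override {
    std::printf("dispatching op %d at key %d to Python\n", op, key);
  }
};

int main() {
  ConcreteVTable live;
  const VTable& vt = live;  // callers only ever see the base interface
  vt.trampoline(42, 7);
}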
test/test_torch.py
@@ -5673,36 +5673,6 @@ class TestTorch(TestCase):
                 r"the unspecified dimension size -1 can be any value and is ambiguous"):
             torch.randn(2, 0).unflatten(1, (2, -1, 0))

-    def test_pytorch_library_disabled_env(self):
-        import subprocess
-        env = os.environ.copy()
-        env['PYTORCH_DISABLE_LIBRARY'] = '1'
-        try:
-            subprocess.check_output([sys.executable, '-c', 'import torch'], env=env)
-        except subprocess.CalledProcessError as e:
-            raise RuntimeError("Could not 'import torch' with PYTORCH_DISABLE_LIBRARY=0") from e
-
-    # Test that warnings generated from C++ are translated to the correct type
-    def test_warn_types(self):
-        test_cases = [
-            # function, warning type, message
-            (torch._C._warn, UserWarning, r"Test message for TORCH_WARN"),
-            (torch._C._warn_deprecation, DeprecationWarning, r"Test message for TORCH_WARN_DEPRECATION"),
-        ]
-
-        for fn, warning_type, message in test_cases:
-            with warnings.catch_warnings(record=True) as w:
-                warnings.resetwarnings()
-                warnings.filterwarnings('always', category=warning_type)
-                fn()
-
-                self.assertEqual(len(w), 1, msg=f'{warning_type} not raised')
-                warning = w[0].message
-                self.assertTrue(isinstance(warning, warning_type), msg=f'{warning_type} not raised')
-                self.assertTrue(re.search(
-                    message,
-                    str(warning)))
-
     def test_structseq_repr(self):
         a = torch.arange(250).reshape(5, 5, 10)
         expected = """
torch/_C/__init__.pyi
@@ -1037,7 +1037,6 @@ class _DispatchModule:
     def def_name_t_t(self, name: str, dispatch: str, debug: str = "default_def_name_t_t") -> _DispatchModule: ...
     def def_schema_t_t(self, schema: str, dispatch: str, alias: str, debug: str = "default_def_schema_t_t") -> _DispatchModule: ...
     def impl_t_t(self, name: str, dispatch: str, debug: str = "impl_t_t") -> _DispatchModule: ...
-    def impl_tt_t(self, name: str, dispatch: str, debug: str = "impl_tt_t") -> _DispatchModule: ...
     def impl(self, name: str, dispatch: str, func: Callable) -> _DispatchModule: ...
     def define(self, schema: str, alias: str = "") -> _DispatchModule: ...
    def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ...
torch/__init__.py
@@ -1021,10 +1021,9 @@ def _register_device_module(device_type, module):

 # expose return_types
 from . import return_types
-if sys.executable != 'torch_deploy' and os.environ.get('PYTORCH_DISABLE_LIBRARY', "0") == "0":
-    from . import library
-    if not TYPE_CHECKING:
-        from . import _meta_registrations
+from . import library
+if not TYPE_CHECKING:
+    from . import _meta_registrations

 # Enable CUDA Sanitizer
 if 'TORCH_CUDA_SANITIZER' in os.environ:
torch/autograd/forward_ad.py
@@ -86,8 +86,7 @@ def make_dual(tensor, tangent, *, level=None):
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
     # Currently broken for 3.11, see https://github.com/pytorch/pytorch/issues/85506
     if (os.environ.get("PYTORCH_JIT", "1" if sys.version_info < (3, 11) else "0") == "1" and
-            __debug__ and
-            os.environ.get('PYTORCH_DISABLE_LIBRARY', "0") == "0"):
+            __debug__):
         from torch._decomp import decompositions_for_jvp  # noqa: F401

     if level is None:
torch/csrc/autograd/python_variable.cpp
@@ -1,8 +1,10 @@
 #include <ATen/NamedTensorUtils.h>
 #include <ATen/core/PythonFallbackKernel.h>
+#include <ATen/core/PythonOpRegistrationTrampoline.h>
 #include <c10/core/DeviceType.h>
 #include <c10/core/SafePyObject.h>
 #include <c10/core/impl/GPUTrace.h>
+#include <c10/core/impl/HermeticPyObjectTLS.h>
 #include <c10/core/impl/PythonDispatcherTLS.h>
 #include <c10/util/DeadlockDetection.h>
 #include <c10/util/irange.h>
@@ -31,6 +33,7 @@
 #include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>
 #include <torch/csrc/utils/python_arg_parser.h>
+#include <torch/csrc/utils/python_dispatch.h>
 #include <torch/csrc/utils/python_numbers.h>
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/tensor_memoryformats.h>
@@ -219,6 +222,14 @@ struct ConcretePyInterpreterVTable final
       const c10::OperatorHandle& op,
       c10::DispatchKeySet,
       torch::jit::Stack* stack) const override;
+  // NB: this is defined in python_dispatch.cpp
+  void python_op_registration_trampoline(
+      const c10::OperatorHandle& op,
+      c10::DispatchKey key,
+      torch::jit::Stack* stack) const override {
+    torch::impl::dispatch::python_op_registration_trampoline_impl(
+        op, key, stack);
+  }

   bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override;
   bool is_strides_like(const TensorImpl* self, at::MemoryFormat) const override;
@@ -294,6 +305,10 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor)
   // THPVariable_clear).
   // 2. We are decref-ing some other Python object.  We don't do
   // PyObject resurrection on non-Tensors, so we just carry on as usual
+  if (is_tensor) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        !c10::impl::HermeticPyObjectTLS::get_state());
+  }
   if (is_tensor && Py_REFCNT(pyobj) > 1) {
     // It's still alive!  This can happen if a weak ref resurrected
     // the PyObject without flipping ownership.  At this point it is
@@ -314,7 +329,10 @@ class PyInterpreterHolder {
  public:
   PyInterpreterHolder()
       : impl_(new c10::impl::PyInterpreter(
-            ConcretePyInterpreterVTable::instance())) {}
+            ConcretePyInterpreterVTable::instance())) {
+    is_main_interpreter_ =
+        at::impl::PythonOpRegistrationTrampoline::registerInterpreter(impl_);
+  }
   // NB: intentionally leaks the PyInterpreter, as there may still be
   // references to it that are live, living in objects that aren't being
   // destructed while Python is being cleaned up.
@@ -324,9 +342,13 @@ class PyInterpreterHolder {
   c10::impl::PyInterpreter* get() const noexcept {
     return impl_;
   }
+  bool is_main_interpreter() const noexcept {
+    return is_main_interpreter_;
+  }

  private:
   c10::impl::PyInterpreter* impl_;
+  bool is_main_interpreter_;
 };
 PyInterpreterHolder self_interpreter;

@@ -352,6 +374,10 @@ c10::impl::PyInterpreter* getPyInterpreter() {
   return self_interpreter.get();
 }

+bool isMainPyInterpreter() {
+  return self_interpreter.is_main_interpreter();
+}
+
 std::string ConcretePyInterpreterVTable::name() const {
   std::stringstream ss;
   ss << getPyInterpreter();
@@ -416,6 +442,13 @@ PyObject* THPVariable_Wrap(at::TensorBase var) {
     Py_RETURN_NONE;
   }

+  if (c10::impl::HermeticPyObjectTLS::get_state()) {
+    return THPVariable_NewWithVar(
+        (PyTypeObject*)THPVariableClass,
+        std::move(var),
+        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
+  }
+
   c10::optional<PyObject*> mb_obj =
       var.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get());
   c10::impl::PyInterpreterStatus status;
@@ -489,6 +522,11 @@ bool isResurrectable(THPVariable* self) {
     return false;
   }
   auto const& tensor = THPVariable_Unpack(self);
+  // Check if this is hermetic.  If it is, no resurrection.
+  if (tensor.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()) !=
+      c10::make_optional((PyObject*)self)) {
+    return false;
+  }
   if (!tensor.defined() || tensor.use_count() <= 1) {
     return false;
   }
@@ -531,6 +569,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
   // Flip THPVariable to be non-owning
   // (near use-after-free miss here: fresh MaybeOwned is created breaking
   // reference on Tensor in struct BEFORE we overwrite the old one)
+  TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
   self->cdata = MaybeOwned<Variable>::borrowed(tensor);

   // NB: At this point, tensor *could* be dead (e.g., some other C++ thread
@@ -582,7 +621,9 @@ static int THPVariable_clear(THPVariable* self) {
     //    unsafeIsBorrowed() is TRUE.  We're deallocating the PyObject
     //    because Tensor asked us to (it's already destructing).

-    if (!self->cdata.unsafeIsBorrowed()) {
+    if (!self->cdata.unsafeIsBorrowed() &&
+        tensor.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()) ==
+            c10::make_optional((PyObject*)self)) {
       // TODO: empirically, on OS X this assert appears to be untrue
       // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
       // distributed/rpc/test_process_group_agent.py
@@ -1887,11 +1928,27 @@ static PyObject* THPVariable_NewWithVar(
     auto v = (THPVariable*)obj;
     // TODO: named constructor to avoid default initialization
     new (&v->cdata) MaybeOwned<Variable>();
-    v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
-    const auto& var = THPVariable_Unpack(v);
-    var.unsafeGetTensorImpl()->init_pyobj(self_interpreter.get(), obj, status);
-    if (check_has_torch_dispatch(obj)) {
-      var.unsafeGetTensorImpl()->set_python_dispatch(true);
+    if (c10::impl::HermeticPyObjectTLS::get_state()) {
+      // Do NOT initialize pyobj field on the tensor, you own the C++
+      v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
+      TORCH_INTERNAL_ASSERT(
+          !check_has_torch_dispatch(obj),
+          "While HermeticPyObject was enabled, we attempted to create a tensor "
+          "subclass with __torch_dispatch__.  This violates the invariant that "
+          "operations in HermeticPyObject have equivalent C++ implementations. "
+          "If your operator registered from Python operator registration isn't "
+          "doing anything strange, there may be an internal PyTorch bug involving "
+          "not appropriately disabling TorchDispatchMode before executing "
+          "Python op registration.");
+    } else {
+      // Normal codepath
+      v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
+      const auto& var = THPVariable_Unpack(v);
+      var.unsafeGetTensorImpl()->init_pyobj(
+          self_interpreter.get(), obj, status);
+      if (check_has_torch_dispatch(obj)) {
+        var.unsafeGetTensorImpl()->set_python_dispatch(true);
+      }
     }
   }
   return obj;
torch/csrc/autograd/python_variable.h
@@ -69,6 +69,7 @@ inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
 }

 TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
+TORCH_PYTHON_API bool isMainPyInterpreter();

 std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
     const c10::OperatorHandle& op,
torch/csrc/utils/python_dispatch.cpp
@@ -4,6 +4,7 @@
 #include <ATen/ATen.h>
 #include <ATen/FuncTorchTLS.h>
 #include <ATen/TensorSubclassLikeUtils.h>
+#include <ATen/core/PythonOpRegistrationTrampoline.h>
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/library.h>

@@ -11,6 +12,7 @@
 #include <torch/csrc/autograd/python_variable.h>
 #include <torch/csrc/jit/python/pybind_utils.h>

+#include <c10/util/flat_hash_map.h>
 #include <pybind11/operators.h>
 #include <pybind11/stl.h>
 #include <torch/csrc/utils/pybind.h>
@@ -23,6 +25,14 @@ namespace torch {
 namespace impl {
 namespace dispatch {

+// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
+// guarantee that the main interpreter has finished doing all registrations before
+// the other interpreters start banging on it
+static ska::flat_hash_map<
+    c10::OperatorName,
+    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
+    python_registrations_;
+
 torch::Library::Kind parseKind(const std::string& k) {
   static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
       {"DEF", torch::Library::DEF},
@@ -58,19 +68,101 @@ inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
   }
 }

+struct EnableHermeticPyObject {
+  EnableHermeticPyObject()
+      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
+        old_excluded_python_(
+            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
+        old_python_(
+            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
+        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
+            at::DispatchKey::PythonTLSSnapshot)) {
+    c10::impl::HermeticPyObjectTLS::set_state(true);
+    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
+    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
+    c10::impl::tls_set_dispatch_key_included(
+        at::DispatchKey::PythonTLSSnapshot, false);
+  }
+  ~EnableHermeticPyObject() {
+    c10::impl::HermeticPyObjectTLS::set_state(old_);
+    c10::impl::tls_set_dispatch_key_excluded(
+        at::DispatchKey::Python, old_excluded_python_);
+    c10::impl::tls_set_dispatch_key_included(
+        at::DispatchKey::Python, old_python_);
+    c10::impl::tls_set_dispatch_key_included(
+        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
+  }
+  bool old_;
+  bool old_excluded_python_;
+  bool old_python_;
+  bool old_python_snapshot_;
+};
+
 class PythonKernelHolder : public c10::OperatorKernel {
   c10::SafePyObject func_;
+  c10::DispatchKey dispatch_key_;

  public:
-  PythonKernelHolder(py::object func)
-      : func_(func.release().ptr(), getPyInterpreter()) {}
+  PythonKernelHolder(py::object func, c10::DispatchKey dispatch_key)
+      : func_(func.release().ptr(), getPyInterpreter()),
+        dispatch_key_(dispatch_key) {}

   void operator()(
       const c10::OperatorHandle& op,
       c10::DispatchKeySet keyset,
       torch::jit::Stack* stack) {
+    // Figure out if we can handle it hermetically, or if we have
+    // to double dispatch
+
+    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
+    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
+    if (mode_stack_len > 0) {
+      const auto& cur_torch_dispatch_mode_state =
+          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
+      cur_torch_dispatch_mode_state->pyinterpreter()
+          ->python_op_registration_trampoline(op, dispatch_key_, stack);
+      return;
+    }
+
+    const auto& schema = op.schema();
+    const auto num_arguments = schema.arguments().size();
+
+    // Otherwise, find a PyInterpreter on a Tensor IF it has Python key (which
+    // means it's a nontrivial tensor subclass)
+    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
+      if (ivalue.isTensor()) {
+        auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
+        if (interpreter &&
+            ivalue.unsafeToTensorImpl()->key_set().has(
+                at::DispatchKey::Python)) {
+          (*interpreter)
+              ->python_op_registration_trampoline(op, dispatch_key_, stack);
+          return;
+        }
+      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
+        // NB: use toListRef as it doesn't induce refcount bumps
+        // (toTensorListRef is not a thing)
+        for (const auto& nv : ivalue.toListRef()) {
+          if (nv.isNone()) {
+            continue;
+          }
+          auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
+          if (interpreter &&
+              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
+            (*interpreter)
+                ->python_op_registration_trampoline(op, dispatch_key_, stack);
+            return;
+          }
+        }
+      }
+    }
+
+    // Nothing requires the operator to be homed to a specific interpreter, so
+    // run it on the current interpreter
+
     auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
     py::gil_scoped_acquire g;
+    EnableHermeticPyObject g2;
     auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
     auto obj = py::reinterpret_steal<py::object>(PyObject_Call(
         func_.ptr(getPyInterpreter()),
@@ -83,6 +175,14 @@ class PythonKernelHolder : public c10::OperatorKernel {
   }
 };

+torch::_RegisterOrVerify register_or_verify() {
+  if (isMainPyInterpreter()) {
+    return torch::_RegisterOrVerify::REGISTER;
+  } else {
+    return torch::_RegisterOrVerify::VERIFY;
+  }
+}
+
 void initDispatchBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();

@@ -91,9 +191,12 @@ void initDispatchBindings(PyObject* module) {

   // TODO: figure out how to do chaining
   py::class_<torch::Library>(m, "_DispatchModule")
+      // Some of these APIs are only for testing and do not work in multipy
+      // environment
       .def(
           "def_",
           [](py::object self, const char* schema, const char* alias) {
+            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
             self.cast<torch::Library&>().def(
                 torch::schema(schema, parseAliasAnalysisKind(alias)));
             return self;
@@ -107,6 +210,7 @@ void initDispatchBindings(PyObject* module) {
       .def(
           "def_legacy",
           [](py::object self, const char* schema) {
+            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
             self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
             return self;
           },
@@ -126,6 +230,7 @@ void initDispatchBindings(PyObject* module) {
              const char* name,
              const char* dispatch,
              const char* debug) {
+            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
             self.cast<torch::Library&>().def(
                 name, dispatch_str(dispatch, [](const at::Tensor& a) {
                   return a;
@@ -143,6 +248,7 @@ void initDispatchBindings(PyObject* module) {
              const char* dispatch,
              const char* alias,
              const char* debug) {
+            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
             self.cast<torch::Library&>().def(
                 torch::schema(schema, parseAliasAnalysisKind(alias)),
                 dispatch_str(dispatch, [](const at::Tensor& a) {
@@ -163,6 +269,7 @@ void initDispatchBindings(PyObject* module) {
              const char* name,
              const char* dispatch,
              const char* debug) {
+            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
             self.cast<torch::Library&>().impl(
                 name, dispatch_str(dispatch, [](const at::Tensor& a) {
                   return a;
@@ -173,38 +280,26 @@ void initDispatchBindings(PyObject* module) {
           py::arg("name"),
           py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
-      .def(
-          "impl_tt_t",
-          [](py::object self,
-             const char* name,
-             const char* dispatch,
-             const char* debug) {
-            self.cast<torch::Library&>().impl(
-                name,
-                dispatch_str(
-                    dispatch,
-                    [](const at::Tensor& a, const at::Tensor& b) { return a; })
-                    .debug(debug));
-            return self;
-          },
-          "",
-          py::arg("name"),
-          py::arg("dispatch") = "",
-          py::arg("debug") = "")
       .def(
           "impl",
           [](py::object self,
              const char* name,
-             const char* dispatch,
+             // TODO: empty string no longer works
+             c10::DispatchKey dispatch,
              py::object func) {
             HANDLE_TH_ERRORS
-            self.cast<torch::Library&>().impl(
+            auto& lib = self.cast<torch::Library&>();
+            lib.impl(
                 name,
-                dispatch_str(
-                    dispatch,
-                    CppFunction::makeFromBoxedFunctor(
-                        std::make_unique<PythonKernelHolder>(
-                            std::move(func)))));
+                torch::dispatch(
+                    dispatch,
+                    CppFunction::makeFromBoxedFunctor(
+                        std::make_unique<PythonKernelHolder>(func, dispatch))),
+                register_or_verify());
+            python_registrations_[lib._resolve(name)].insert_or_assign(
+                dispatch,
+                std::make_shared<c10::SafePyObject>(
+                    func.release().ptr(), getPyInterpreter()));
             END_HANDLE_TH_ERRORS_PYBIND
           },
           "",
@@ -214,8 +309,11 @@ void initDispatchBindings(PyObject* module) {
       .def(
           "define",
           [](py::object self, const char* schema, const char* alias_analysis) {
+            auto parsed_schema =
+                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
             self.cast<torch::Library&>().def(
-                torch::schema(schema, parseAliasAnalysisKind(alias_analysis)));
+                std::move(parsed_schema), {}, register_or_verify());
+            // TODO: this is dumb, had to make a second copy
             return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                 .name();
           },
@@ -225,6 +323,7 @@ void initDispatchBindings(PyObject* module) {
       .def(
           "fallback_fallthrough",
           [](py::object self, const char* dispatch) {
+            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
             self.cast<torch::Library&>().fallback(
                 dispatch_str(dispatch, CppFunction::makeFallthrough()));
             return self;
@@ -488,6 +587,9 @@ void initDispatchBindings(PyObject* module) {
       },
       py::arg("dispatch_key") = static_cast<const char*>(""));

+  m.def(
+      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
+
   m.def("_are_functorch_transforms_active", []() {
     auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
     return (
@@ -496,6 +598,26 @@ void initDispatchBindings(PyObject* module) {
   });
 }

+// TODO: dedupe with the kernel
+void python_op_registration_trampoline_impl(
+    const c10::OperatorHandle& op,
+    c10::DispatchKey key,
+    torch::jit::Stack* stack) {
+  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
+  py::gil_scoped_acquire g;
+  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
+  const auto& func = python_registrations_[op.operator_name()][key];
+  TORCH_INTERNAL_ASSERT(func != nullptr);
+  auto* pyobj = func->ptr(getPyInterpreter());
+  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
+  auto obj = py::reinterpret_steal<py::object>(
+      PyObject_Call(pyobj, args_kwargs.first.ptr(), args_kwargs.second.ptr()));
+  if (!obj) {
+    throw python_error();
+  }
+  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
+}
+
 } // namespace dispatch
 } // namespace impl
 } // namespace torch
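EnableHermeticPyObject above is a save/set/restore RAII guard over several TLS bits; PythonKernelHolder relies on its destructor to undo the hermetic state even if the Python kernel throws. A generic standalone sketch of that guard shape (TLSGuard is illustrative, not a PyTorch class):

#include <cassert>
#include <utility>

thread_local bool hermetic = false;  // stand-in for HermeticPyObjectTLS

// Save the old value in the constructor, restore it in the destructor,
// the same shape as EnableHermeticPyObject's ctor/dtor pairs.
template <typename T>
class TLSGuard {
 public:
  TLSGuard(T& slot, T value) : slot_(slot), old_(slot) {
    slot_ = std::move(value);
  }
  ~TLSGuard() { slot_ = std::move(old_); }
  TLSGuard(const TLSGuard&) = delete;
  TLSGuard& operator=(const TLSGuard&) = delete;

 private:
  T& slot_;
  T old_;
};

int main() {
  {
    TLSGuard<bool> g(hermetic, true);  // like entering EnableHermeticPyObject
    assert(hermetic);
    // ... invoke the Python kernel hermetically; an exception here still
    // runs ~TLSGuard and restores the previous state ...
  }
  assert(!hermetic);
}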
torch/csrc/utils/python_dispatch.h
@@ -7,6 +7,11 @@ namespace dispatch {

 void initDispatchBindings(PyObject* module);

-}
+void python_op_registration_trampoline_impl(
+    const c10::OperatorHandle& op,
+    c10::DispatchKey key,
+    torch::jit::Stack* stack);
+
+} // namespace dispatch
 } // namespace impl
 } // namespace torch
torch/library.h
@@ -86,6 +86,12 @@ namespace torch {
 struct NoInferSchemaTag {};
 #endif

+// For multipy/torchdeploy use case
+enum class _RegisterOrVerify {
+  REGISTER,
+  VERIFY
+};
+
 template <class CurClass>
 class class_;

@@ -591,9 +597,9 @@ class TORCH_API Library final {
   /// ```
   template <typename Schema>
-  Library& def(Schema&& raw_schema, const std::vector<at::Tag>& tags = {}) & {
+  Library& def(Schema&& raw_schema, const std::vector<at::Tag>& tags = {}, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
     c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
-    return _def(std::move(s), nullptr, tags);
+    return _def(std::move(s), nullptr, tags, rv);
   }
   /// Define an operator for a schema and then register an implementation for
   /// it.  This is typically what you would use if you aren't planning
@@ -644,7 +650,7 @@ class TORCH_API Library final {
   /// }
   /// ```
   template <typename Name, typename Func>
-  Library& impl(Name name, Func&& raw_f) & {
+  Library& impl(Name name, Func&& raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
     // TODO: need to raise an error when you impl a function that has a
     // catch all def
 #if defined C10_MOBILE
@@ -652,7 +658,7 @@ class TORCH_API Library final {
 #else
     CppFunction f(std::forward<Func>(raw_f));
 #endif
-    return _impl(name, std::move(f));
+    return _impl(name, std::move(f), rv);
   }

 #if defined C10_MOBILE
@@ -673,6 +679,10 @@ class TORCH_API Library final {
   }
 #endif

+  // Helper for getting an OperatorName for a const char*.  You probably
+  // don't need this.
+  c10::OperatorName _resolve(const char* name) const;
+
   /// \private
   ///
   /// Convenience overload for directly specifying the dispatch key when
@@ -809,12 +819,17 @@ class TORCH_API Library final {
   Library& _def(
       c10::FunctionSchema&& schema,
       c10::OperatorName* out_name = nullptr,
-      const std::vector<at::Tag>& tags = {}) &;
+      const std::vector<at::Tag>& tags = {},
+      _RegisterOrVerify rv = _RegisterOrVerify::REGISTER
+  ) &;
   Library& _def(
       c10::either<c10::OperatorName, c10::FunctionSchema>&&,
       CppFunction&& f) &;
-  Library& _impl(const char* name, CppFunction&& f) &;
+  Library& _impl(const char* name, CppFunction&& f,
+    _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) &;
   Library& _fallback(CppFunction&& f) &;

+  at::OperatorName _parseNameForLib(const char* name_str) const;
+
 };

 namespace detail {
torch/library.py
@@ -2,7 +2,6 @@ from ._ops import OpOverload
 from typing import Set
 import traceback
 import torch
-import os

 __all__ = ['Library', 'impl', 'define']

@@ -30,9 +29,6 @@ class Library:
         dispatch_key: PyTorch dispatch key (default: "")
     """
     def __init__(self, ns, kind, dispatch_key=""):
-        if os.environ.get('PYTORCH_DISABLE_LIBRARY', "0") == "1":
-            raise RuntimeError("Trying to use torch.library in an environment where it is disabled")
-
         if kind != "IMPL" and kind != "DEF":
             raise ValueError("Unsupported kind: ", kind)

@@ -126,7 +122,8 @@ class Library:
                     " Instead we should let the operator decompose, and ensure that we have meta kernels"
                     " for the base ops that it decomposes into.")

-        self.m.impl(name, dispatch_key, fn)
+        self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn)

         _impls.add(key)
         self._op_impls.add(key)