We want to make TorchRec sharded models TorchScriptable.

TorchRec sharded models use the generic types `Awaitable[W]` and `LazyAwaitable[W]` (https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L212). In a sharded model these types stand in for the contained type `W` and hold an initialization function that produces an object of type `W`. The moment the first attribute of `W` is requested, `LazyAwaitable[W]` calls its initialization function (on the same stack), caches the result inside, and from then on works transparently as an object of `W`. We can think of it as delayed object initialization.

To support this behavior in TorchScript, we propose a new TorchScript type: `Await`. In eager mode it works the same as `LazyAwaitable[W]` in TorchRec, being dynamically typed: it acts as type `W` while it is an `Await[W]`. Within TorchScript it is an `Await[W]` and can only be explicitly converted to `W` using the special function `torch.jit._awaitable_wait(aw)`. An `Await[W]` is created via another special function, `torch.jit._awaitable(func, *args)`.

The semantics are close to `torch.jit.Future`, `fork`, and `wait`, and use the same JIT mechanics (inlined fork closures), with the difference that the function is not started in parallel on fork. It is only stored as a lambda inside the IValue and is called on the same thread when `torch.jit._awaitable_wait` is called.

For example (more examples in this PR, `test/jit/test_await.py`):
```
def delayed(z: int) -> int:
    return z * 3

@torch.jit.script
def fn(x: Tensor):
    aw: Await[int] = torch.jit._awaitable(delayed, 99)
    a = torch.eye(2)
    b = torch.jit._awaitable_wait(aw)
    return a + b + x
```

Function semantics:

`_awaitable(func: Callable[..., W], *args, **kwargs) -> Await[W]`
Creates an Await object that owns `args` and `kwargs`. The first `_awaitable_wait` call executes `func` and caches its result in the Await object; subsequent `_awaitable_wait` calls return that cached result.

`_awaitable_wait(Await[W]) -> W`
Returns the cached result if this is not the first `_awaitable_wait` call on this Await object; otherwise calls the stored function.

`_awaitable_nowait(W) -> Await[W]`
Creates a trivial `Await[W]` wrapper around the specified object, to stay type compliant in corner cases.

Differential Revision: [D42502706](https://our.internmc.facebook.com/intern/diff/D42502706)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90863
Approved by: https://github.com/davidberard98
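As a quick illustration of the lazy, call-once semantics described above, here is a small sketch using the same `torch.jit._awaitable*` entry points (variable names are illustrative; `test/jit/test_await.py` has the authoritative examples):
```
import torch
from torch import Tensor

def heavy(x: Tensor) -> Tensor:
    return x * 3

@torch.jit.script
def fn(x: Tensor) -> Tensor:
    # Nothing executes here: the call is stored as a lambda inside the IValue.
    aw = torch.jit._awaitable(heavy, x)
    # The first wait runs heavy() on the current thread and caches the result.
    y = torch.jit._awaitable_wait(aw)
    # A second wait returns the cached result without re-running heavy().
    z = torch.jit._awaitable_wait(aw)
    # _awaitable_nowait wraps an already-computed value in a trivial Await.
    nw = torch.jit._awaitable_nowait(y + z)
    return torch.jit._awaitable_wait(nw)
```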
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/ATen.h>
#include <ATen/core/alias_info.h>
#include <ATen/core/interned_strings.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/edit_distance.h>

#include <queue>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

namespace {
using OperatorMap =
    std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
struct OperatorRegistry {
 private:
  std::mutex lock;
  OperatorMap operators;
  // list of operators whose schema have not yet been parsed, and must
  // be registered before any call to lookup an operator
  std::vector<std::shared_ptr<Operator>> to_register;
  // Those two maps are used to implement lookupByLiteral, which is needed for
  // the n->match(...) calls. Basically, every function schema is assigned a
  // unique string you can use to match it. However, parsing those strings or
  // comparing and hashing them character by character would be very slow, so we
  // use a trick here! Every string literal in your program is guaranteed to
  // have static storage duration and so its address won't change at runtime.
  // This allows us to memoize answers for every pointer, which is done by the
  // operators_by_sig_literal map. Still, this map is initially empty, and so we
  // still need to do the complete string matching at the first time, which is
  // implemented by performing a lookup in the operators_by_sig map.
  std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
  std::unordered_map<const char*, std::shared_ptr<Operator>>
      operators_by_sig_literal;

  // Remember all registered operator names to check that they aren't
  // registered a second time. Registering an op multiple times is
  // fragile because it might depend on static initialization order
  // which one is picked at runtime.
#ifdef C10_MOBILE
  std::unordered_set<c10::OperatorName> registered_operator_names;
#endif

  // XXX - caller must be holding lock
  void registerPendingOperators() {
    for (const auto& op : to_register) {
      Symbol sym = Symbol::fromQualString(op->schema().name());
      operators[sym].push_back(op);
      operators_by_sig[canonicalSchemaString(op->schema())] = op;
    }
    to_register.clear();
  }

 public:
  void registerOperator(Operator&& op) {
    std::lock_guard<std::mutex> guard(lock);
#ifdef C10_MOBILE
    TORCH_INTERNAL_ASSERT(
        0 == registered_operator_names.count(op.schema().operator_name()),
        "Tried to register operator \"",
        op.schema(),
        "\" to JIT but the operator name was already registered before. Please add or change the overload name.");
    registered_operator_names.insert(op.schema().operator_name());
#endif
    to_register.push_back(std::make_shared<Operator>(std::move(op)));
  }

  void deregisterOperator(const FunctionSchema& schema) {
    Symbol sym = Symbol::fromQualString(schema.name());
    auto sig = canonicalSchemaString(schema);

    std::lock_guard<std::mutex> guard(lock);
#ifdef C10_MOBILE
    TORCH_INTERNAL_ASSERT(
        1 == registered_operator_names.count(schema.operator_name()),
        "Tried to remove operator ",
        schema,
        " from JIT but it wasn't found.");
    registered_operator_names.erase(schema.operator_name());
#endif
    // Try removing from pending operators list first
    auto pending_it = to_register.begin();
    while (pending_it != to_register.end() && (*pending_it)->schema() != schema)
      ++pending_it;

    if (pending_it != to_register.end()) {
      to_register.erase(pending_it);
      return;
    }

    // Remove operator from signature map
    auto sig_it = operators_by_sig.find(sig);
    if (sig_it == operators_by_sig.end()) {
      return;
    }

    operators_by_sig.erase(sig_it);

    // Remove operator from symbol map
    auto op_it = operators.find(sym);
    TORCH_CHECK(
        op_it != operators.end(),
        "operator with signature ",
        sig,
        " is missing from symbol registry");

    auto& op_vec = op_it->second;
    auto it = op_vec.begin();
    while (it != op_vec.end() && (*it)->schema() != schema)
      ++it;
    if (it != op_vec.end()) {
      op_vec.erase(it);
    }
    if (op_vec.empty()) {
      operators.erase(op_it);
    }
  }

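  // Note: the first call for a given literal pays for a full schema parse and
  // an operators_by_sig lookup; later calls with the same string literal hit
  // the operators_by_sig_literal pointer cache (see the comment on the two
  // maps above).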
  const std::shared_ptr<Operator>& lookupByLiteral(const char* name) {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();
    auto it = operators_by_sig_literal.find(name);
    if (it == operators_by_sig_literal.end()) {
      auto op_ptr_it =
          operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
      // Handy debugging code that dumps all operators we know about on mismatch
#if 0
      if (op_ptr_it == operators_by_sig.end()) {
        for (auto & entry : operators_by_sig) {
          std::cout << entry.first << std::endl;
        }
      }
#endif
      TORCH_CHECK(
          op_ptr_it != operators_by_sig.end(),
          "Couldn't find an operator for ",
          name,
          ". Do you have to update a set of hardcoded JIT ops?");
      it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
    }
    return it->second;
  }

  const std::vector<std::shared_ptr<Operator>>& getOperators(Symbol name) {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();
    static std::vector<std::shared_ptr<Operator>> empty;
    auto it = operators.find(name);
    if (it != operators.end())
      return it->second;
    return empty;
  }

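  // "Did you mean?" support: returns registered symbols whose qualified name
  // is within edit distance MAX_EDIT_DIST of input_op, closest matches first
  // (e.g. a typo like aten::ad would rank aten::add near the front).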
  std::vector<Symbol> findSimilarOperators(Symbol input_op) {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();

    using EntryPair = std::pair<int64_t, Symbol>;
    auto cmp = [](const EntryPair& lhs, const EntryPair& rhs) {
      return lhs.first > rhs.first;
    };

    std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)>
        rankings(cmp);
    static constexpr size_t MAX_EDIT_DIST = 2u;
    for (const auto& op : operators) {
      auto edit_dist = ComputeEditDistance(
          input_op.toQualString(), op.first.toQualString(), MAX_EDIT_DIST);
      if (edit_dist <= MAX_EDIT_DIST) {
        rankings.emplace(edit_dist, op.first);
      }
    }
    std::vector<Symbol> ret;
    while (!rankings.empty()) {
      ret.push_back(rankings.top().second);
      rankings.pop();
    }
    return ret;
  }

  const std::vector<std::shared_ptr<Operator>> getAllOperators() {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();
    std::vector<std::shared_ptr<Operator>> values;
    for (auto& kv : operators) {
      values.insert(values.end(), kv.second.begin(), kv.second.end());
    }
    return values;
  }
};

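// Meyers singleton: the registry is constructed on first use, and C++11
// guarantees thread-safe initialization of function-local statics.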
OperatorRegistry& getRegistry() {
  static OperatorRegistry r;
  return r;
}

bool printerHasSpecialCaseFor(Symbol sym) {
  using namespace at;
  // WARNING: by adding a value to this set, you are asserting
  // that you have also added special handling of this symbol to
  // the python_print.cpp. Not adding handling will cause import and export
  // of modules with this new operator to fail. This is only required
  // for operators without schema. Prefer registering your operator with
  // schema to editing this list here. These cases should only be things
  // that require special handling because they do not fit normal schema
  const static std::unordered_set<Symbol> handled = {
      prim::Constant, prim::Uninitialized, prim::fork,
      prim::awaitable, prim::ListConstruct, prim::DictConstruct,
      prim::ListUnpack, prim::Print, prim::PythonOp,
      prim::TupleConstruct, prim::TupleIndex, prim::TupleSlice,
      prim::TupleUnpack, prim::CreateObject, prim::GetAttr,
      prim::SetAttr, prim::CallFunction, prim::isinstance,
      prim::unchecked_cast, prim::tolist, prim::rpc_async,
      prim::rpc_sync, prim::rpc_remote};

  // WARNING: by adding a value to this set, you are asserting that your
  // primitive is only ever added during optimization and does not need
  // to be correctly printed for export (a process that happens before
  // optimization passes run)
  const static std::unordered_set<Symbol> unneeded = {
      c10::onnx::Reshape, // only used in onnx
      c10::onnx::Shape, // only used in onnx
      prim::AutogradZero, // temporarily inserted by autograd
      prim::AutogradAnyNonZero, // temporarily inserted by autograd
      prim::AutogradAllNonZero, // temporarily inserted by autograd
      prim::AutogradAllZero, // temporarily inserted by autograd
      prim::AutogradAdd, // temporarily inserted by autograd
      prim::ConstantChunk, // optimization pass adds it
      prim::DifferentiableGraph, // optimization pass adds it
      prim::FunctionalGraph, // optimization pass adds it
      prim::ReductionSizes, // optimization pass (fuser) adds it
      prim::BroadcastSizes, // optimization pass (fuser) adds it
      prim::ChunkSizes, // optimization pass (fuser) adds it
      prim::Drop, // used in interpreter only
      prim::FusedConcat, // optimization pass adds it
      prim::FusionGroup, // optimization pass adds it
      prim::CudaFusionGroup, // optimization pass adds it
      prim::CudaFusionGuard, // optimization pass adds it
      prim::TensorExprGroup, // optimization pass adds it
      prim::TensorExprDynamicGroup, // optimization pass adds it
      prim::StaticSubgraph, // optimization pass adds it
      prim::ConstantMKLDNNTensor, // optimization pass adds it
      prim::BroadcastMKLDNNTensors, // optimization pass adds it
      prim::oneDNNFusionGroup, // optimization pass adds it
      prim::oneDNNFusionGuard, // optimization pass adds it
      prim::StaticRuntimeCopyOuts, // used in SR only
      prim::Load, // used in interpreter only
      prim::MMTreeReduce, // used as an optimization
      prim::MMBatchSide, // used as an optimization
      prim::Store, // used in interpreter only
      prim::profile, // used in interpreter only
      prim::profile_ivalue, // used in interpreter only
      prim::TypeCheck, // used in interpreter only
      prim::RequiresGradCheck, // used in interpreter only
      prim::FallbackGraph, // converted into prim::CallFunction
  };

  // These namespaces are required to have Python printers unless
  // otherwise noted in unneeded.
  const static std::unordered_set<Symbol> required_namespaces = {
      c10::namespaces::prim,
      c10::namespaces::aten,
      c10::namespaces::onnx,
  };

  return handled.count(sym) || unneeded.count(sym) ||
      !required_namespaces.count(sym.ns());
}

} // anonymous namespace

bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
  using namespace at;
  // WARNING: by adding a case to this list, you are asserting that you have
  // added a case for the unschematized node in AliasDb::analyze
  const static std::unordered_set<Symbol> handled = {
      prim::If,
      prim::Loop,
      prim::FusionGroup,
      prim::CudaFusionGroup,
      prim::oneDNNFusionGroup,
      prim::DifferentiableGraph,
      prim::TensorExprGroup,
      prim::TensorExprDynamicGroup,
      prim::StaticSubgraph,
      prim::FunctionalGraph,
      prim::Constant,
      prim::Uninitialized,
      prim::DictConstruct,
      prim::ListConstruct,
      prim::TupleConstruct,
      prim::AutogradZero,
      prim::FusedConcat,
      prim::GradOf,
      prim::MMTreeReduce,
      prim::MMBatchSide,
      prim::BroadcastSizes,
      prim::ChunkSizes,
      prim::Closure,
      prim::TupleUnpack,
      prim::TupleIndex,
      prim::TupleSlice,
      prim::ListUnpack,
      prim::PythonOp,
      prim::ConstantChunk,
      prim::BroadcastingChunk,
      prim::MKLDNNGroup,
      prim::ConstantMKLDNNTensor,
      prim::BroadcastMKLDNNTensors,
      prim::fork,
      prim::awaitable,
      prim::awaitable_nowait,
      prim::awaitable_wait,
      prim::CreateObject,
      prim::AutogradAdd,
      prim::GetAttr,
      prim::SetAttr,
      prim::profile,
      prim::profile_ivalue,
      prim::TypeCheck,
      prim::RequiresGradCheck,
      prim::Print,
      prim::CallFunction,
      prim::CallMethod,
      aten::wait,
      prim::isinstance,
      prim::unchecked_cast,
      prim::tolist,
      prim::rpc_async,
      prim::rpc_sync,
      prim::rpc_remote,
      prim::Enter,
      prim::Exit,
      prim::FallbackGraph,
  };

  // Operators that should not be used by alias analysis
  const static std::unordered_set<Symbol> purposefully_not_handled = {
      prim::Load,
      prim::Store,
      prim::Drop,
      at::onnx::Reshape,
      at::onnx::Shape,
      prim::AutogradAdd,
  };

  return handled.count(symbol) || purposefully_not_handled.count(symbol);
}

void registerOperator(Operator&& op) {
  if (op.schema().is_varret()) {
    Symbol s = Symbol::fromQualString(op.schema().name());
    if (!printerHasSpecialCaseFor(s)) {
      AT_ERROR(
          "Missing special case in python printer for non-schematized"
          " operator ",
          op.schema().name(),
          ". File a bug to add a case for this operator.\n");
    }
    if (aliasAnalysisHasSpecialCaseFor(s) &&
        op.aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE) {
      AT_ERROR(
          "Conflict in special casing in alias analysis for non-schematized"
          " operator ",
          op.schema().name(),
          ". File a bug to add a case for this operator.\n");
    }
    if (aliasAnalysisHasSpecialCaseFor(s) &&
        op.aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA) {
      AT_ERROR(
          "The operator ",
          op.schema().name(),
          " is special cased and cannot use explicit alias analysis.");
    }
  }
  getRegistry().registerOperator(std::move(op));
}

void deregisterOperator(const FunctionSchema& schema) {
  getRegistry().deregisterOperator(schema);
}

const std::vector<std::shared_ptr<Operator>> getAllOperators() {
  return getRegistry().getAllOperators();
}

const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name) {
  return getRegistry().getOperators(name);
}

std::shared_ptr<Operator> findOperatorFor(const c10::OperatorName& full_name) {
  for (const auto& op :
       getRegistry().getOperators(Symbol::fromQualString(full_name.name))) {
    if (op->schema().overload_name() == full_name.overload_name) {
      return op;
    }
  }
  return nullptr;
}

std::vector<Symbol> findSimilarOperators(Symbol input_op) {
  return getRegistry().findSimilarOperators(input_op);
}

std::shared_ptr<Operator> getOperatorForLiteral(const char* signature) {
  return getRegistry().lookupByLiteral(signature);
}

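// Canonical form: qualified name, argument types and names (a "*" marks the
// start of keyword-only arguments), and return types. Overload names, default
// values, and alias annotations are all dropped, so e.g.
//   aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
// canonicalizes to
//   aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor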
std::string canonicalSchemaString(const FunctionSchema& schema) {
  std::string out = schema.name();
  out.push_back('(');

  bool seen_kwarg_only = false;
  for (const auto i : c10::irange(schema.arguments().size())) {
    if (i > 0) {
      out += ", ";
    }
    if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
      out += "*, ";
      seen_kwarg_only = true;
    }
    const auto& arg = schema.arguments()[i];
    out += arg.type()->str();
    out.push_back(' ');
    out += arg.name();
  }

  out += ") -> ";
  if (schema.returns().size() == 1) {
    out += schema.returns().at(0).type()->str();
  } else if (schema.returns().size() > 1) {
    out.push_back('(');
    for (const auto i : c10::irange(schema.returns().size())) {
      if (i > 0) {
        out += ", ";
      }
      out += schema.returns()[i].type()->str();
    }
    out.push_back(')');
  }
  return out;
}

} // namespace jit
} // namespace torch