mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This PR adds the final set of clang-tidy checks we should add for our codebase: a last set of performance-related checks. Most fixes here are around changing `auto` to `const auto&` in a few places where unnecessary copies were made, and adding `reserve()` calls before loops doing repeated `push_back()`. Also a few cases of calling `std::string::find` with a single-character string literal instead of a single char, which uses a less efficient string search algorithm meant for searching larger substrings.  ezyang apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/15198 Differential Revision: D13468797 Pulled By: goldsborough fbshipit-source-id: 2bed1ea1c7c162b7f3e0e1026f17125e88c4d5b2
578 lines
18 KiB
C++
578 lines
18 KiB
C++
#include <ATen/ATen.h>
|
|
#include <torch/csrc/jit/alias_info.h>
|
|
#include <torch/csrc/jit/script/lexer.h>
|
|
#include <torch/csrc/jit/script/parse_string_literal.h>
|
|
#include <torch/csrc/jit/script/tree.h>
|
|
#include <torch/csrc/jit/operator.h>
|
|
#include <torch/csrc/jit/passes/python_print.h>
|
|
#include <torch/csrc/jit/script/error_report.h>
|
|
|
|
#include <functional>
|
|
#include <memory>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace torch { namespace jit {
|
|
|
|
namespace script {
|
|
struct SchemaParser {
|
|
SchemaParser(const std::string& str)
|
|
: L(str) {}
|
|
|
|
FunctionSchema parseDeclaration() {
|
|
auto name = L.expect(TK_IDENT).text();
|
|
if(L.nextIf(':')) {
|
|
L.expect(':');
|
|
name = name + "::" + L.expect(TK_IDENT).text();
|
|
}
|
|
std::vector<Argument> arguments;
|
|
std::vector<Argument> returns;
|
|
bool kwarg_only = false;
|
|
bool is_vararg = false;
|
|
size_t idx = 0;
|
|
parseList('(', ',', ')', [&] {
|
|
if(is_vararg)
|
|
throw ErrorReport(L.cur()) << "... must be the last element of the argument list";
|
|
if (L.nextIf('*')) {
|
|
kwarg_only = true;
|
|
} else if(L.nextIf(TK_DOTS)) {
|
|
is_vararg = true;
|
|
} else {
|
|
arguments.push_back(parseArgument(
|
|
idx++, /*is_return=*/false, /*kwarg_only=*/kwarg_only));
|
|
}
|
|
});
|
|
idx = 0;
|
|
L.expect(TK_ARROW);
|
|
if (L.cur().kind == '(') {
|
|
parseList('(', ',', ')', [&] {
|
|
returns.push_back(
|
|
parseArgument(idx++, /*is_return=*/true, /*kwarg_only=*/false));
|
|
});
|
|
} else {
|
|
returns.push_back(
|
|
parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false));
|
|
}
|
|
return FunctionSchema{
|
|
name, std::move(arguments), std::move(returns), is_vararg, false};
|
|
}
|
|
|
|
std::vector<FunctionSchema> parseDeclarations() {
|
|
std::vector<FunctionSchema> results;
|
|
do {
|
|
results.push_back(parseDeclaration());
|
|
} while(L.nextIf(TK_NEWLINE));
|
|
L.expect(TK_EOF);
|
|
return results;
|
|
}
|
|
|
|
TreeRef parseIdent() {
|
|
return String::create(L.expect(TK_IDENT).text());
|
|
}
|
|
TypePtr parseBaseType() {
|
|
static std::unordered_map<std::string, TypePtr> type_map = {
|
|
{"Generator", GeneratorType::get() },
|
|
{"ScalarType", IntType::get() },
|
|
{"Layout", IntType::get() },
|
|
{"Device", DeviceObjType::get() },
|
|
{"Scalar", NumberType::get() },
|
|
{"str", StringType::get() },
|
|
{"float", FloatType::get() },
|
|
{"int", IntType::get() },
|
|
{"bool", BoolType::get() },
|
|
};
|
|
auto tok = L.expect(TK_IDENT);
|
|
auto text = tok.text();
|
|
auto it = type_map.find(text);
|
|
if(it == type_map.end()) {
|
|
if(text.size() > 0 && islower(text[0])) {
|
|
// lower case identifiers that are not otherwise valid types
|
|
// are treated as type variables
|
|
return VarType::create(text);
|
|
}
|
|
throw ErrorReport(tok.range) << "unknown type specifier";
|
|
}
|
|
return it->second;
|
|
}
|
|
// Examples:
|
|
// Tensor(a) // Tensor is in set a
|
|
// Tensor(a!) // it is also written to
|
|
// Tensor! // shorthand for Tensor(fresh_identifier!)
|
|
// Tensor(a! -> a|b) // Tensor is in set a, written to,
|
|
// and after the write is in set a AND b.
|
|
c10::optional<AliasInfo> parseAliasAnnotation() {
|
|
std::set<Symbol> sets;
|
|
AliasInfo alias_info;
|
|
if (L.nextIf('(')) {
|
|
// optional 'alias set annotation'
|
|
parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
|
|
if (L.nextIf('*')) {
|
|
alias_info = AliasInfo::createWildcard();
|
|
|
|
// If we found a wildcard, ignore all subsequent annotations
|
|
} else if (!alias_info.isWildcard()) {
|
|
alias_info.addSet(
|
|
Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
|
|
}
|
|
});
|
|
if (L.nextIf('!')) {
|
|
alias_info.setIsWrite(true);
|
|
}
|
|
L.expect(')');
|
|
} else if (L.nextIf('!')) {
|
|
alias_info.addSet(
|
|
Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
|
|
alias_info.setIsWrite(true);
|
|
} else{
|
|
return c10::nullopt;
|
|
}
|
|
|
|
return alias_info;
|
|
}
|
|
|
|
std::pair<TypePtr, c10::optional<AliasInfo>> parseType() {
|
|
TypePtr value;
|
|
c10::optional<AliasInfo> alias_info;
|
|
// Tuple type
|
|
if (L.cur().kind == '(') {
|
|
std::vector<TypePtr> types;
|
|
parseList('(', ',', ')', [&] {
|
|
auto r = parseType();
|
|
types.push_back(std::move(r.first));
|
|
if (alias_info && r.second) {
|
|
alias_info->addContainedType(std::move(*r.second));
|
|
}
|
|
});
|
|
value = TupleType::create(std::move(types));
|
|
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
|
|
L.next(); // Future
|
|
L.expect('(');
|
|
auto p = parseType();
|
|
auto subtype = std::move(p.first);
|
|
auto subalias = std::move(p.second);
|
|
L.expect(')');
|
|
value = FutureType::create(subtype);
|
|
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
|
|
L.next();
|
|
value = DynamicType::get();
|
|
alias_info = parseAliasAnnotation();
|
|
} else {
|
|
value = parseBaseType();
|
|
}
|
|
while(true) {
|
|
if(L.cur().kind == '[' && L.lookahead().kind == ']') {
|
|
L.next(); // [
|
|
L.next(); // ]
|
|
value = ListType::create(value);
|
|
auto container = parseAliasAnnotation();
|
|
if (container && alias_info) {
|
|
container->addContainedType(std::move(*alias_info));
|
|
}
|
|
alias_info = std::move(container);
|
|
} else if(L.nextIf('?')) {
|
|
value = OptionalType::create(value);
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
return std::make_pair(std::move(value), std::move(alias_info));
|
|
}
|
|
|
|
Argument parseArgument(size_t idx, bool is_return, bool kwarg_only) {
|
|
Argument result;
|
|
auto p = parseType();
|
|
auto type = std::move(p.first);
|
|
auto alias_info = std::move(p.second);
|
|
c10::optional<int32_t> N;
|
|
c10::optional<IValue> default_value;
|
|
c10::optional<std::string> alias_set;
|
|
std::string name;
|
|
if(L.nextIf('[')) {
|
|
// note: an array with a size hint can only occur at the Argument level
|
|
type = ListType::create(type);
|
|
N = std::stoll(L.expect(TK_NUMBER).text());
|
|
L.expect(']');
|
|
auto container = parseAliasAnnotation();
|
|
if (container && alias_info) {
|
|
container->addContainedType(std::move(*alias_info));
|
|
}
|
|
alias_info = std::move(container);
|
|
}
|
|
if(is_return) {
|
|
// optionally named return values
|
|
if(L.cur().kind == TK_IDENT) {
|
|
name = L.next().text();
|
|
} else {
|
|
name = "ret" + std::to_string(idx);
|
|
}
|
|
} else {
|
|
name = L.expect(TK_IDENT).text();
|
|
if(L.nextIf('=')) {
|
|
default_value = parseDefaultValue(type, N);
|
|
}
|
|
}
|
|
return Argument(
|
|
std::move(name),
|
|
std::move(type),
|
|
N,
|
|
std::move(default_value),
|
|
!is_return && kwarg_only,
|
|
std::move(alias_info));
|
|
}
|
|
IValue parseSingleConstant(TypeKind kind) {
|
|
switch(L.cur().kind) {
|
|
case TK_TRUE:
|
|
L.next();
|
|
return true;
|
|
case TK_FALSE:
|
|
L.next();
|
|
return false;
|
|
case TK_NONE:
|
|
L.next();
|
|
return IValue();
|
|
case TK_STRINGLITERAL: {
|
|
auto token = L.next();
|
|
return parseStringLiteral(token.range, token.text());
|
|
}
|
|
case TK_IDENT: {
|
|
auto tok = L.next();
|
|
auto text = tok.text();
|
|
if("float" == text) {
|
|
return static_cast<int64_t>(at::kFloat);
|
|
} else if("strided" == text) {
|
|
return static_cast<int64_t>(at::kStrided);
|
|
} else if("Mean" == text) {
|
|
return static_cast<int64_t>(Reduction::Mean);
|
|
} else {
|
|
throw ErrorReport(L.cur().range) << "invalid numeric default value";
|
|
}
|
|
}
|
|
default:
|
|
std::string n;
|
|
if(L.nextIf('-'))
|
|
n = "-" + L.expect(TK_NUMBER).text();
|
|
else
|
|
n = L.expect(TK_NUMBER).text();
|
|
if(kind == TypeKind::FloatType || n.find('.') != std::string::npos || n.find('e') != std::string::npos) {
|
|
return std::stod(n);
|
|
} else {
|
|
int64_t v = std::stoll(n);
|
|
return v;
|
|
}
|
|
}
|
|
}
|
|
IValue convertToList(TypeKind kind, const SourceRange& range, std::vector<IValue> vs) {
|
|
switch(kind) {
|
|
case TypeKind::FloatType:
|
|
return fmap(vs, [](IValue v) {
|
|
return v.toDouble();
|
|
});
|
|
case TypeKind::IntType:
|
|
return fmap(vs, [](IValue v) {
|
|
return v.toInt();
|
|
});
|
|
case TypeKind::BoolType:
|
|
return fmap(vs, [](IValue v) {
|
|
return v.toBool();
|
|
});
|
|
default:
|
|
throw ErrorReport(range) << "lists are only supported for float or int types.";
|
|
}
|
|
}
|
|
IValue parseConstantList(TypeKind kind) {
|
|
auto tok = L.expect('[');
|
|
std::vector<IValue> vs;
|
|
if(L.cur().kind != ']') {
|
|
do {
|
|
vs.push_back(parseSingleConstant(kind));
|
|
} while(L.nextIf(','));
|
|
}
|
|
L.expect(']');
|
|
return convertToList(kind, tok.range, std::move(vs));
|
|
}
|
|
|
|
IValue parseTensorDefault(const SourceRange& range) {
|
|
L.expect(TK_NONE);
|
|
return IValue();
|
|
}
|
|
IValue parseDefaultValue(const TypePtr& arg_type, c10::optional<int32_t> arg_N) {
|
|
auto range = L.cur().range;
|
|
switch(arg_type->kind()) {
|
|
case TypeKind::DynamicType:
|
|
case TypeKind::GeneratorType: {
|
|
return parseTensorDefault(range);
|
|
} break;
|
|
case TypeKind::StringType:
|
|
case TypeKind::OptionalType:
|
|
case TypeKind::NumberType:
|
|
case TypeKind::IntType:
|
|
case TypeKind::BoolType:
|
|
case TypeKind::FloatType:
|
|
return parseSingleConstant(arg_type->kind());
|
|
break;
|
|
case TypeKind::DeviceObjType: {
|
|
auto device_text = parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
|
|
return c10::Device(device_text);
|
|
break;
|
|
}
|
|
case TypeKind::ListType: {
|
|
auto elem_kind = arg_type->cast<ListType>()->getElementType();
|
|
if(L.cur().kind == TK_IDENT) {
|
|
return parseTensorDefault(range);
|
|
} else if(arg_N && L.cur().kind != '[') {
|
|
IValue v = parseSingleConstant(elem_kind->kind());
|
|
std::vector<IValue> repeated(*arg_N, v);
|
|
return convertToList(elem_kind->kind(), range, repeated);
|
|
} else {
|
|
return parseConstantList(elem_kind->kind());
|
|
}
|
|
} break;
|
|
default:
|
|
throw ErrorReport(range) << "unexpected type, file a bug report";
|
|
}
|
|
return IValue(); // silence warnings
|
|
}
|
|
|
|
void parseList(int begin, int sep, int end, const std::function<void()>& callback) {
|
|
auto r = L.cur().range;
|
|
if (begin != TK_NOTHING)
|
|
L.expect(begin);
|
|
if (L.cur().kind != end) {
|
|
do {
|
|
callback();
|
|
} while (L.nextIf(sep));
|
|
}
|
|
if (end != TK_NOTHING)
|
|
L.expect(end);
|
|
}
|
|
Lexer L;
|
|
size_t next_id = 0;
|
|
};
|
|
|
|
} // namespace script
|
|
|
|
namespace {
|
|
|
|
// Renders `schema` in the canonical text form used as a lookup key:
//   name(T1 a, *, T2 b) -> R        (one return)
//   name(...) -> (R1, R2)           (several returns)
//   name(...) ->                    (no returns)
// Argument defaults and alias annotations are not included; a single "*, "
// is emitted before the first keyword-only argument.
std::string canonicalSchemaString(const FunctionSchema& schema) {
  std::ostringstream out;
  out << schema.name() << "(";

  const auto& args = schema.arguments();
  bool emitted_kwarg_marker = false;
  for (size_t pos = 0; pos < args.size(); ++pos) {
    const auto& arg = args[pos];
    if (pos != 0) {
      out << ", ";
    }
    if (!emitted_kwarg_marker && arg.kwarg_only()) {
      out << "*, ";
      emitted_kwarg_marker = true;
    }
    out << arg.type()->str() << " " << arg.name();
  }

  out << ") -> ";
  const auto& rets = schema.returns();
  if (rets.size() > 1) {
    out << "(";
    for (size_t pos = 0; pos < rets.size(); ++pos) {
      if (pos != 0) {
        out << ", ";
      }
      out << rets[pos].type()->str();
    }
    out << ")";
  } else if (rets.size() == 1) {
    out << rets.at(0).type()->str();
  }
  return out.str();
}
|
|
|
|
using OperatorMap = std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
|
|
struct OperatorRegistry {
|
|
private:
|
|
std::mutex lock;
|
|
OperatorMap operators;
|
|
// list of operators whose schema have not yet been parsed, and must
|
|
// be registered before any call to lookup an opeator
|
|
std::vector<std::shared_ptr<Operator>> to_register;
|
|
// Those two maps are used to implement lookupByLiteral, which is needed for the n->match(...) calls.
|
|
// Basically, every function schema is assigned a unique string you can use to match it. However,
|
|
// parsing those strings or comparing and hashing them character by character would be very slow, so
|
|
// we use a trick here! Every string literal in your program is guaranteed to have static storage
|
|
// duration and so its address won't change at runtime. This allows us to memoize answers for every
|
|
// pointer, which is done by the operators_by_sig_literal map. Still, this map is initially
|
|
// empty, and so we still need to do the complete string matching at the first time, which is implemented
|
|
// by performing a lookup in the operators_by_sig map.
|
|
std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
|
|
std::unordered_map<const char *, std::shared_ptr<Operator>> operators_by_sig_literal;
|
|
|
|
// XXX - caller must be holding lock
|
|
void registerPendingOperators() {
|
|
for(const auto& op : to_register) {
|
|
Symbol sym = Symbol::fromQualString(op->schema().name());
|
|
operators[sym].push_back(op);
|
|
operators_by_sig[canonicalSchemaString(op->schema())] = op;
|
|
}
|
|
to_register.clear();
|
|
}
|
|
|
|
public:
|
|
void registerOperator(Operator&& op) {
|
|
std::lock_guard<std::mutex> guard(lock);
|
|
to_register.push_back(std::make_shared<Operator>(std::move(op)));
|
|
}
|
|
|
|
const std::shared_ptr<Operator>& lookupByLiteral(const char * name) {
|
|
std::lock_guard<std::mutex> guard(lock);
|
|
registerPendingOperators();
|
|
auto it = operators_by_sig_literal.find(name);
|
|
if (it == operators_by_sig_literal.end()) {
|
|
auto op_ptr_it = operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
|
|
// Handy debugging code that dumps all operators we know about on mismatch
|
|
#if 0
|
|
if (op_ptr_it == operators_by_sig.end()) {
|
|
for (auto & entry : operators_by_sig) {
|
|
std::cout << entry.first << std::endl;
|
|
}
|
|
}
|
|
#endif
|
|
JIT_ASSERTM(op_ptr_it != operators_by_sig.end(), "Couldn't find an operator for ", name);
|
|
it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
|
|
}
|
|
return it->second;
|
|
}
|
|
|
|
|
|
const std::vector<std::shared_ptr<Operator>>& getOperators(Symbol name) {
|
|
std::lock_guard<std::mutex> guard(lock);
|
|
registerPendingOperators();
|
|
static std::vector<std::shared_ptr<Operator>> empty;
|
|
auto it = operators.find(name);
|
|
if(it != operators.end())
|
|
return it->second;
|
|
return empty;
|
|
}
|
|
};
|
|
|
|
// Returns the process-wide registry, constructed on first use.
OperatorRegistry& getRegistry() {
  static OperatorRegistry registry;
  return registry;
}
|
|
|
|
} // anonymous namespace
|
|
|
|
// Adds `op` to the global registry. If its schema reports is_varret() and
// the python printer has no special case for the operator's symbol, a
// warning is written to stdout first; the operator is registered either way.
void registerOperator(Operator&& op) {
  const bool is_varret = op.schema().is_varret();
  if (is_varret &&
      !printerHasSpecialCaseFor(Symbol::fromQualString(op.schema().name()))) {
    std::cout << c10::str(
        "missing special case in python printer for non-schematized operator ",
        op.schema().name(),
        ". File a bug to add a case for this operator.\n");
  }
  getRegistry().registerOperator(std::move(op));
}
|
|
|
|
// Returns all overloads registered under `name` (empty if there are none).
const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name) {
  OperatorRegistry& registry = getRegistry();
  return registry.getOperators(name);
}
|
|
|
|
// Looks up the operator whose schema matches `signature`. The registry
// caches the result keyed by the pointer value, so callers are expected to
// pass string literals.
Operator& sig(const char *signature) {
  const std::shared_ptr<Operator>& entry = getRegistry().lookupByLiteral(signature);
  return *entry;
}
|
|
|
|
// Parses a schema string and returns its first declaration.
// parseDeclarations() requires the whole string to be consumed and always
// yields at least one declaration (or throws).
FunctionSchema parseSchema(const std::string& schema) {
  auto declarations = script::SchemaParser(schema).parseDeclarations();
  // The vector is discarded on return, so move the result out instead of
  // copying it (NRVO does not apply to a container element).
  return std::move(declarations.at(0));
}
|
|
|
|
bool Operator::matches(const Node* node) const {
|
|
// wrong name
|
|
if (node->kind().toQualString() != schema().name()) {
|
|
return false;
|
|
}
|
|
at::ArrayRef<const Value*> actuals = node->inputs();
|
|
const auto& formals = schema().arguments();
|
|
|
|
// not enough inputs
|
|
if(actuals.size() < formals.size())
|
|
return false;
|
|
|
|
|
|
TypeEnv type_env;
|
|
for(size_t i = 0; i < formals.size(); ++i) {
|
|
const MatchTypeReturn matched_type =
|
|
matchTypeVariables(formals[i].type(), actuals[i]->type(), type_env);
|
|
if (!matched_type.type) {
|
|
return false;
|
|
}
|
|
TypePtr formal = *matched_type.type;
|
|
if (!actuals[i]->type()->isSubtypeOf(formal)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// too many inputs
|
|
if(!schema().is_vararg() && actuals.size() != formals.size()) {
|
|
// std::cout << "not all inputs used\n" << input_i << " " << inputs_size << "\n";
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
std::shared_ptr<Operator> findOperatorFor(const Node* node) {
|
|
const auto& candidates = getAllOperatorsFor(node->kind());
|
|
for(const auto& candidate : candidates) {
|
|
if(candidate->matches(node)) {
|
|
return candidate;
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
// Returns the operator that matches `node`. If there is none, throws an
// ErrorReport listing the node, its input types, every candidate schema
// registered for its kind, and the owning graph.
const Operator& getOperatorFor(const Node* node) {
  if (auto op = findOperatorFor(node)) {
    return *op;
  }

  script::ErrorReport er(node->getSourceLocation());
  er << "Schema not found for node. File a bug report.\n";
  er << "Node: " << *node << "\n";
  er << "Input types:";
  const char* separator = "";
  for (const Value* input : node->inputs()) {
    er << separator << *input->type();
    separator = ", ";
  }
  er << "\ncandidates were:\n";
  for (const auto& candidate : getAllOperatorsFor(node->kind())) {
    er << " " << candidate->schema() << "\n";
  }
  er << *node->owningGraph() << "\n";
  throw er;
}
|
|
|
|
|
|
// Builds a set from schema string literals, grouping the resolved operators
// by their qualified name symbol.
OperatorSet::OperatorSet(std::initializer_list<const char *> sig_literals) {
  auto & registry = getRegistry();
  for (const char * sig : sig_literals) {
    // lookupByLiteral returns a reference to the registry's cached
    // shared_ptr; binding by reference avoids an extra refcount round-trip,
    // and only the push_back below takes a new reference.
    const auto& op = registry.lookupByLiteral(sig);
    ops[Symbol::fromQualString(op->schema().name())].push_back(op);
  }
}
|
|
|
|
// Returns the first operator in this set, among those stored under
// n->kind(), whose schema matches `n`; nullptr if none does.
Operator* OperatorSet::find(const Node *n) const {
  const auto bucket = ops.find(n->kind());
  if (bucket != ops.end()) {
    for (const auto& candidate : bucket->second) {
      if (candidate->matches(n)) {
        return candidate.get();
      }
    }
  }
  return nullptr;
}
|
|
|
|
}}
|