mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Enable all clang-tidy performance checks (#15198)
Summary: This PR adds the final set of clang-tidy checks we should add for our codebase: a last set of performance-related checks. Most fixes here are around changing `auto` to `const auto&` in a few places where unnecessary copies were made, and adding `reserve()` calls before loops doing repeated `push_back()`. Also a few cases of calling `std::string::find` with a single-character string literal instead of a single char, which uses a less efficient string search algorithm meant for searching larger substrings.  ezyang apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/15198 Differential Revision: D13468797 Pulled By: goldsborough fbshipit-source-id: 2bed1ea1c7c162b7f3e0e1026f17125e88c4d5b2
This commit is contained in:
committed by
Facebook Github Bot
parent
fc2856e9aa
commit
7a61306031
@ -24,7 +24,8 @@ Checks: '
|
||||
,-modernize-use-auto
|
||||
,-modernize-use-default-member-init
|
||||
,-modernize-use-using
|
||||
,performance-unnecessary-value-param
|
||||
,performance-*
|
||||
,-performance-noexcept-move-constructor
|
||||
'
|
||||
WarningsAsErrors: '*'
|
||||
HeaderFilterRegex: 'torch/csrc/.*'
|
||||
|
@ -50,7 +50,7 @@ struct python_error : public std::exception {
|
||||
}
|
||||
|
||||
python_error(python_error&& other) {
|
||||
type = std::move(other.type);
|
||||
type = other.type;
|
||||
value = other.value;
|
||||
traceback = other.traceback;
|
||||
other.type = nullptr;
|
||||
|
@ -150,7 +150,7 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
|
||||
/// \endrst
|
||||
void apply(
|
||||
const NamedModuleApplyFunction& function,
|
||||
std::string name_prefix = std::string());
|
||||
const std::string& name_prefix = std::string());
|
||||
|
||||
/// Applies the `function` to the `Module` and recursively to every submodule.
|
||||
/// The function must accept a `const std::string&` for the key of the module,
|
||||
@ -167,7 +167,7 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
|
||||
/// \endrst
|
||||
void apply(
|
||||
const ConstNamedModuleApplyFunction& function,
|
||||
std::string name_prefix = std::string()) const;
|
||||
const std::string& name_prefix = std::string()) const;
|
||||
|
||||
/// Applies the `function` to the `Module` and recursively to every submodule.
|
||||
/// The function must accept a `const std::shared_ptr<Module>&`.
|
||||
@ -198,7 +198,7 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
|
||||
/// \endrst
|
||||
void apply(
|
||||
const NamedModulePointerApplyFunction& function,
|
||||
std::string name_prefix = std::string()) const;
|
||||
const std::string& name_prefix = std::string()) const;
|
||||
|
||||
/// Returns the parameters of this `Module` and if `recurse` is true, also
|
||||
/// recursively of every submodule.
|
||||
@ -243,7 +243,7 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
|
||||
/// stored in a `shared_ptr`.
|
||||
/// \endrst
|
||||
OrderedDict<std::string, std::shared_ptr<Module>> named_modules(
|
||||
std::string name_prefix = std::string(),
|
||||
const std::string& name_prefix = std::string(),
|
||||
bool include_self = true) const;
|
||||
|
||||
/// Returns the direct submodules of this `Module`.
|
||||
|
@ -51,11 +51,11 @@ uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
|
||||
return value;
|
||||
}
|
||||
|
||||
std::string join_paths(std::string head, std::string tail) {
|
||||
std::string join_paths(std::string head, const std::string& tail) {
|
||||
if (head.back() != '/') {
|
||||
head.push_back('/');
|
||||
}
|
||||
head += std::move(tail);
|
||||
head += tail;
|
||||
return head;
|
||||
}
|
||||
|
||||
|
@ -101,26 +101,26 @@ void Module::apply(const ConstModuleApplyFunction& function) const {
|
||||
|
||||
void Module::apply(
|
||||
const NamedModuleApplyFunction& function,
|
||||
std::string name_prefix) {
|
||||
const std::string& name_prefix) {
|
||||
function(/*name=*/name_prefix, *this);
|
||||
apply_to_submodules(
|
||||
[&function](
|
||||
const std::string& name, const std::shared_ptr<Module>& module) {
|
||||
function(name, *module);
|
||||
},
|
||||
std::move(name_prefix));
|
||||
name_prefix);
|
||||
}
|
||||
|
||||
void Module::apply(
|
||||
const ConstNamedModuleApplyFunction& function,
|
||||
std::string name_prefix) const {
|
||||
const std::string& name_prefix) const {
|
||||
function(/*name=*/name_prefix, *this);
|
||||
apply_to_submodules(
|
||||
[&function](
|
||||
const std::string& name, const std::shared_ptr<Module>& module) {
|
||||
function(name, *module);
|
||||
},
|
||||
std::move(name_prefix));
|
||||
name_prefix);
|
||||
}
|
||||
|
||||
void Module::apply(const ModulePointerApplyFunction& function) const {
|
||||
@ -133,10 +133,10 @@ void Module::apply(const ModulePointerApplyFunction& function) const {
|
||||
|
||||
void Module::apply(
|
||||
const NamedModulePointerApplyFunction& function,
|
||||
std::string name_prefix) const {
|
||||
const std::string& name_prefix) const {
|
||||
function(
|
||||
/*name=*/name_prefix, shared_from_this_checked());
|
||||
apply_to_submodules(function, std::move(name_prefix));
|
||||
apply_to_submodules(function, name_prefix);
|
||||
}
|
||||
|
||||
std::vector<Tensor> Module::parameters(bool recurse) const {
|
||||
@ -199,7 +199,7 @@ std::vector<std::shared_ptr<Module>> Module::modules(bool include_self) const {
|
||||
}
|
||||
|
||||
OrderedDict<std::string, std::shared_ptr<Module>> Module::named_modules(
|
||||
std::string name_prefix,
|
||||
const std::string& name_prefix,
|
||||
bool include_self) const {
|
||||
OrderedDict<std::string, std::shared_ptr<Module>> result;
|
||||
if (include_self) {
|
||||
@ -208,14 +208,14 @@ OrderedDict<std::string, std::shared_ptr<Module>> Module::named_modules(
|
||||
const std::string& key, const std::shared_ptr<Module>& module) {
|
||||
result.insert(key, module);
|
||||
},
|
||||
std::move(name_prefix));
|
||||
name_prefix);
|
||||
} else {
|
||||
apply_to_submodules(
|
||||
[&result](
|
||||
const std::string& key, const std::shared_ptr<Module>& module) {
|
||||
result.insert(key, module);
|
||||
},
|
||||
std::move(name_prefix));
|
||||
name_prefix);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -329,7 +329,7 @@ void Module::apply_to_submodules(
|
||||
for (const auto& child : children_) {
|
||||
auto qualified_name = join_name(name_prefix, child.key());
|
||||
function(qualified_name, child.value());
|
||||
child.value()->apply_to_submodules(function, std::move(qualified_name));
|
||||
child.value()->apply_to_submodules(function, qualified_name);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -13,8 +13,7 @@ namespace torch {
|
||||
namespace nn {
|
||||
BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {}
|
||||
|
||||
BatchNormImpl::BatchNormImpl(BatchNormOptions options)
|
||||
: options(std::move(options)) {
|
||||
BatchNormImpl::BatchNormImpl(BatchNormOptions options) : options(options) {
|
||||
reset();
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,7 @@ EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension)
|
||||
: count_(count), dimension_(dimension) {}
|
||||
|
||||
EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options)
|
||||
: options(std::move(options)) {
|
||||
: options(options) {
|
||||
reset();
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,7 @@ namespace torch {
|
||||
namespace nn {
|
||||
LinearOptions::LinearOptions(int64_t in, int64_t out) : in_(in), out_(out) {}
|
||||
|
||||
LinearImpl::LinearImpl(LinearOptions options) : options(std::move(options)) {
|
||||
LinearImpl::LinearImpl(LinearOptions options) : options(options) {
|
||||
reset();
|
||||
}
|
||||
|
||||
|
@ -131,7 +131,7 @@ RNNOutput RNNImplBase<Derived>::generic_forward(
|
||||
}
|
||||
Tensor output, new_state;
|
||||
std::tie(output, new_state) = function(
|
||||
std::move(input),
|
||||
input,
|
||||
std::move(state),
|
||||
flat_weights_,
|
||||
options.with_bias_,
|
||||
@ -208,12 +208,12 @@ RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
|
||||
case RNNActivation::ReLU:
|
||||
return generic_forward(
|
||||
static_cast<RNNFunctionSignature*>(&torch::rnn_relu),
|
||||
std::move(input),
|
||||
input,
|
||||
std::move(state));
|
||||
case RNNActivation::Tanh:
|
||||
return generic_forward(
|
||||
static_cast<RNNFunctionSignature*>(&torch::rnn_tanh),
|
||||
std::move(input),
|
||||
input,
|
||||
std::move(state));
|
||||
default:
|
||||
AT_ERROR("Unhandled RNN activation function!");
|
||||
@ -244,7 +244,7 @@ RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) {
|
||||
}
|
||||
Tensor output, hidden_state, cell_state;
|
||||
std::tie(output, hidden_state, cell_state) = torch::lstm(
|
||||
std::move(input),
|
||||
input,
|
||||
{state[0], state[1]},
|
||||
flat_weights_,
|
||||
options.with_bias_,
|
||||
@ -266,9 +266,7 @@ GRUImpl::GRUImpl(const GRUOptions& options)
|
||||
|
||||
RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) {
|
||||
return generic_forward(
|
||||
static_cast<RNNFunctionSignature*>(&torch::gru),
|
||||
std::move(input),
|
||||
std::move(state));
|
||||
static_cast<RNNFunctionSignature*>(&torch::gru), input, std::move(state));
|
||||
}
|
||||
} // namespace nn
|
||||
} // namespace torch
|
||||
|
@ -17,6 +17,7 @@ void serialize(
|
||||
const std::string& key,
|
||||
const std::vector<int64_t>& steps) {
|
||||
std::vector<torch::Tensor> tensors;
|
||||
tensors.reserve(steps.size());
|
||||
for (const auto& step : steps) {
|
||||
tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
|
||||
}
|
||||
|
@ -153,7 +153,7 @@ inline Tensor as_variable(Tensor tensor) {
|
||||
|
||||
inline std::vector<Tensor> as_variable(TensorList tl) {
|
||||
return fmap(tl, [](const Tensor& t) -> Tensor {
|
||||
return make_variable(std::move(t), /*requires_grad=*/false);
|
||||
return make_variable(t, /*requires_grad=*/false);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -408,7 +408,7 @@ static variable_list call_function(FunctionTask& task) {
|
||||
|
||||
if(has_post_hooks){
|
||||
// NOLINTNEXTLINE(bugprone-use-after-move)
|
||||
return call_post_hooks(fn, std::move(outputs), std::move(inputs));
|
||||
return call_post_hooks(fn, std::move(outputs), inputs);
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
@ -585,10 +585,7 @@ static Node* _trace_pre_record(
|
||||
Py_INCREF(op_obj);
|
||||
auto pyobj = THPObjectPtr(op_obj);
|
||||
return jit::tracer::preRecordPythonTrace(
|
||||
std::move(pyobj),
|
||||
std::move(arg_types),
|
||||
input_vars,
|
||||
std::move(scalar_args));
|
||||
std::move(pyobj), arg_types, input_vars, std::move(scalar_args));
|
||||
}
|
||||
|
||||
static void _trace_post_record(
|
||||
|
@ -31,7 +31,7 @@ BatchTensor::BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dim
|
||||
sizes[0] = bs;
|
||||
mask_sizes[0] = bs;
|
||||
for(int64_t i = 1; i < dims.size(0) + 1; i++){
|
||||
for(auto x : datalist){
|
||||
for(const auto& x : datalist){
|
||||
sizes[i] = std::max(sizes[i], x.size(i));
|
||||
}
|
||||
mask_sizes[i] = *dims[i - 1].data<uint8_t>() ? sizes[i] : 1;
|
||||
|
@ -102,19 +102,19 @@ RegisterOperators reg({
|
||||
return 0;
|
||||
};
|
||||
} else if(type->isSubtypeOf(ListType::ofInts())) {
|
||||
auto is = node->is(attr::value);
|
||||
const auto& is = node->is(attr::value);
|
||||
return [is](Stack& stack) {
|
||||
push(stack, is);
|
||||
return 0;
|
||||
};
|
||||
} else if(type->isSubtypeOf(ListType::ofBools())) {
|
||||
auto bs = node->is(attr::value);
|
||||
const auto& bs = node->is(attr::value);
|
||||
return [bs](Stack& stack) {
|
||||
push(stack, bs);
|
||||
return 0;
|
||||
};
|
||||
} else if(type->isSubtypeOf(ListType::ofTensors())) {
|
||||
auto ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor {
|
||||
const auto& ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor {
|
||||
return autograd::make_variable(t);
|
||||
});
|
||||
return [ts](Stack& stack) {
|
||||
@ -122,7 +122,7 @@ RegisterOperators reg({
|
||||
return 0;
|
||||
};
|
||||
} else if (type == StringType::get()) {
|
||||
auto s = node->s(attr::value);
|
||||
const auto& s = node->s(attr::value);
|
||||
return [s](Stack& stack) {
|
||||
push(stack, s);
|
||||
return 0;
|
||||
|
@ -63,7 +63,7 @@ Node* getTracedNode(
|
||||
const std::tuple<Types...>& tuple) {
|
||||
auto symbol = Symbol::fromQualString(schema.name());
|
||||
const auto& graph = tracer::getTracingState()->graph;
|
||||
Node* node = graph->create(std::move(symbol), /*num_outputs=*/0);
|
||||
Node* node = graph->create(symbol, /*num_outputs=*/0);
|
||||
tracer::recordSourceLocation(node);
|
||||
|
||||
// Hack to call addInputs for the parameter pack in a sequenced fashion.
|
||||
|
@ -51,7 +51,7 @@ int debugFuser() {
|
||||
// If the given node is used once by a chunk node, returns that node.
|
||||
// Returns nullptr otherwise.
|
||||
static const Node* usedInFusedChunk(const Value* input) {
|
||||
const auto uses = input->uses();
|
||||
const auto& uses = input->uses();
|
||||
if (uses.size() == 1) {
|
||||
const Node *user = uses[0].user;
|
||||
if (user->kind() == prim::ConstantChunk) {
|
||||
|
@ -503,7 +503,7 @@ private:
|
||||
}
|
||||
|
||||
void runTraced(Stack & stack) {
|
||||
auto state = tracer::getTracingState();
|
||||
const auto& state = tracer::getTracingState();
|
||||
auto inputs = last(stack, num_inputs);
|
||||
auto input_values = fmap(inputs, [](const IValue & v) {
|
||||
return tracer::getNestedValueTrace(v);
|
||||
|
@ -300,7 +300,7 @@ void initJITBindings(PyObject *module) {
|
||||
m.def("_jit_get_operation", [](const std::string& qualified_name) {
|
||||
try {
|
||||
auto symbol = Symbol::fromQualString(qualified_name);
|
||||
auto operations = getAllOperatorsFor(std::move(symbol));
|
||||
auto operations = getAllOperatorsFor(symbol);
|
||||
AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
|
||||
AT_CHECK(
|
||||
operations.size() == 1,
|
||||
@ -338,7 +338,7 @@ void initJITBindings(PyObject *module) {
|
||||
});
|
||||
m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
|
||||
auto symbol = Symbol::fromQualString(qualified_name);
|
||||
auto operations = getAllOperatorsFor(std::move(symbol));
|
||||
auto operations = getAllOperatorsFor(symbol);
|
||||
return fmap(operations, [](const std::shared_ptr<Operator>& op) {
|
||||
return op->schema();
|
||||
});
|
||||
|
@ -73,7 +73,7 @@ struct Suspend : public std::exception {
|
||||
|
||||
struct InterpreterContinuation {
|
||||
InterpreterContinuation(InterpreterState state_, Stack stack_)
|
||||
: state(std::move(state_)), stack(std::move(stack_)) {}
|
||||
: state(state_), stack(std::move(stack_)) {}
|
||||
|
||||
void operator()() {
|
||||
state.runAsync(stack);
|
||||
|
@ -253,7 +253,7 @@ struct SchemaParser {
|
||||
n = "-" + L.expect(TK_NUMBER).text();
|
||||
else
|
||||
n = L.expect(TK_NUMBER).text();
|
||||
if(kind == TypeKind::FloatType || n.find(".") != std::string::npos || n.find("e") != std::string::npos) {
|
||||
if(kind == TypeKind::FloatType || n.find('.') != std::string::npos || n.find('e') != std::string::npos) {
|
||||
return std::stod(n);
|
||||
} else {
|
||||
int64_t v = std::stoll(n);
|
||||
@ -405,7 +405,7 @@ private:
|
||||
|
||||
// XXX - caller must be holding lock
|
||||
void registerPendingOperators() {
|
||||
for(auto op : to_register) {
|
||||
for(const auto& op : to_register) {
|
||||
Symbol sym = Symbol::fromQualString(op->schema().name());
|
||||
operators[sym].push_back(op);
|
||||
operators_by_sig[canonicalSchemaString(op->schema())] = op;
|
||||
|
@ -531,7 +531,7 @@ void AliasDb::addAlias(const Value* value, Symbol alias) {
|
||||
valueToAlias_[value].addSet(alias);
|
||||
} else {
|
||||
AliasInfo aliasInfo;
|
||||
aliasInfo.addSet(std::move(alias));
|
||||
aliasInfo.addSet(alias);
|
||||
valueToAlias_.insert({value, std::move(aliasInfo)});
|
||||
}
|
||||
}
|
||||
|
@ -973,6 +973,7 @@ void PeepholeOptimizeShapeExpressions(Block * block) {
|
||||
}
|
||||
if (unique_to_value.size() != node->inputs().size()) {
|
||||
std::vector<Value*> inputs;
|
||||
inputs.reserve(unique_to_value.size());
|
||||
for (auto & entry : unique_to_value) {
|
||||
inputs.push_back(entry.second);
|
||||
}
|
||||
|
@ -637,7 +637,7 @@ struct PythonPrintPass {
|
||||
} else if(v.isTensorList()) {
|
||||
stmt << "[";
|
||||
const char* delim = "";
|
||||
for(auto t : v.toTensorListRef()) {
|
||||
for(const auto& t : v.toTensorListRef()) {
|
||||
stmt << delim << "CONSTANTS.c" << getOrAddTensorConstant(t);
|
||||
delim = ", ";
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ void specializeUndef(Graph & g) {
|
||||
std::unordered_map<Value*, State> state;
|
||||
|
||||
for (Value* input : g.inputs()) {
|
||||
auto tp = input->type();
|
||||
const auto& tp = input->type();
|
||||
if (tp->isSubtypeOf(UndefinedTensorType::get())) {
|
||||
state[input] = State::Undefined;
|
||||
} else if (tp->isSubtypeOf(DynamicType::get())) {
|
||||
|
@ -467,7 +467,7 @@ void initPythonIRBindings(PyObject * module_) {
|
||||
.def(py::init([](std::vector<TypePtr> a){ return TupleType::create(a); }))
|
||||
.def("elements", [](TupleType &self){
|
||||
std::vector<TypePtr> types;
|
||||
for (auto type : self.elements()) {
|
||||
for (const auto& type : self.elements()) {
|
||||
types.push_back(type);
|
||||
}
|
||||
return types;
|
||||
|
@ -165,7 +165,7 @@ void initPythonTracerBindings(PyObject* module) {
|
||||
return setValueTrace(var, value);
|
||||
});
|
||||
m.def("_tracer_set_get_unique_name_fn", [](py::function func) {
|
||||
auto tracing_state = getTracingState();
|
||||
const auto& tracing_state = getTracingState();
|
||||
JIT_ASSERT(tracing_state);
|
||||
tracing_state->lookup_var_name_fn = [func](const Variable& var) -> std::string {
|
||||
AutoGIL ag;
|
||||
@ -173,7 +173,7 @@ void initPythonTracerBindings(PyObject* module) {
|
||||
};
|
||||
});
|
||||
m.def("_tracer_set_force_outplace", [](bool force_outplace) {
|
||||
auto tracing_state = getTracingState();
|
||||
const auto& tracing_state = getTracingState();
|
||||
JIT_ASSERT(tracing_state);
|
||||
tracing_state->force_outplace = force_outplace;
|
||||
});
|
||||
|
@ -401,7 +401,7 @@ RegisterOperators reg({
|
||||
return [=](Stack& stack) {
|
||||
bool result = false;
|
||||
for (const IValue& t : last(stack, num_inputs)) {
|
||||
if (std::move(t).toTensor().defined()) {
|
||||
if (t.toTensor().defined()) {
|
||||
result = true;
|
||||
break;
|
||||
}
|
||||
@ -1135,6 +1135,7 @@ Operator( \
|
||||
at::Tensor t;
|
||||
pop(stack, t);
|
||||
std::vector<int64_t> elems;
|
||||
elems.reserve(t.size(0));
|
||||
for(int i = 0; i < t.size(0); i++){
|
||||
elems.push_back(*t[i].data<int32_t>());
|
||||
}
|
||||
|
@ -419,7 +419,7 @@ static inline bool isIntOrFloatUsedAsList(
|
||||
const Value* value,
|
||||
const Argument& arg) {
|
||||
// Look for int[N] or float[N]
|
||||
auto v_type = value->type();
|
||||
const auto& v_type = value->type();
|
||||
if (v_type != FloatType::get() && v_type != IntType::get())
|
||||
return false;
|
||||
auto arg_type = unwrapOptional(arg.type());
|
||||
@ -1054,7 +1054,7 @@ private:
|
||||
<< " return (" << schema.returns().size() << ") does not match"
|
||||
<< " the number of returns from the function (" << results.size() << ")!";
|
||||
}
|
||||
auto range = return_stmt.range();
|
||||
const auto& range = return_stmt.range();
|
||||
size_t return_type_idx = 0;
|
||||
for (auto r : results) {
|
||||
TypePtr type = DynamicType::get();
|
||||
@ -1506,7 +1506,7 @@ private:
|
||||
auto instances = sv->asTuple(stmt.range(), method);
|
||||
const std::string& target_name = target.name();
|
||||
pushFrame(environment_stack->block());
|
||||
for(auto inst : instances) {
|
||||
for(const auto& inst : instances) {
|
||||
environment_stack->setSugaredVar(itrs[0].range(), target_name, inst);
|
||||
emitStatements(body);
|
||||
}
|
||||
@ -1988,7 +1988,7 @@ private:
|
||||
if(maybe_unpack && tree->kind() == TK_STARRED) {
|
||||
auto starred = Starred(tree);
|
||||
auto entries = emitSugaredExpr(starred.expr(), 1)->asTuple(starred.range(), method);
|
||||
for(auto entry : entries) {
|
||||
for(const auto& entry : entries) {
|
||||
values.emplace_back(
|
||||
tree->range(), entry->asValue(starred.range(), method));
|
||||
}
|
||||
@ -2639,7 +2639,7 @@ void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::vector<D
|
||||
auto resolver_it = resolvers.begin();
|
||||
std::vector<Method*> methods;
|
||||
std::unordered_map<std::string, Method*> function_table;
|
||||
for(Def def : definitions) {
|
||||
for(const Def& def : definitions) {
|
||||
const std::string& name = def.name().name();
|
||||
auto resolver = *resolver_it++;
|
||||
JIT_ASSERT(resolver);
|
||||
|
@ -121,7 +121,7 @@ private:
|
||||
|
||||
struct TORCH_API BuiltinFunction : public SugaredValue {
|
||||
BuiltinFunction(Symbol symbol, c10::optional<NamedValue> self)
|
||||
: symbol(std::move(symbol)), self(std::move(self)) {}
|
||||
: symbol(symbol), self(std::move(self)) {}
|
||||
|
||||
// The symbol of the function (e.g. `aten::relu`).
|
||||
Symbol symbol;
|
||||
|
@ -185,7 +185,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
|
||||
const SourceRange& loc,
|
||||
Method& m) override {
|
||||
std::vector<Value*> values;
|
||||
for (auto sugared_item : asTuple(loc, m)) {
|
||||
for (const auto& sugared_item : asTuple(loc, m)) {
|
||||
values.push_back(sugared_item->asValue(loc, m));
|
||||
}
|
||||
auto node = m.graph()->createTuple(values);
|
||||
@ -532,6 +532,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||
const std::vector<ResolutionCallback>& rcbs,
|
||||
const std::vector<FunctionDefaults>& defaults) {
|
||||
std::vector<Resolver> resolvers;
|
||||
resolvers.reserve(rcbs.size());
|
||||
for(auto & callback : rcbs) {
|
||||
resolvers.push_back(pythonResolver(callback));
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ struct Method {
|
||||
stack.push_back(*inp);
|
||||
}
|
||||
const auto size = stack.size();
|
||||
setInputTypes(*retval, ArgumentSpec(with_grad, std::move(stack), size));
|
||||
setInputTypes(*retval, ArgumentSpec(with_grad, stack, size));
|
||||
PropagateInputShapes(retval);
|
||||
return retval;
|
||||
}
|
||||
|
@ -107,10 +107,10 @@ void initTreeViewBindings(PyObject *module) {
|
||||
.def(py::init([](const Ident& name,
|
||||
Decl decl,
|
||||
std::vector<Stmt> body) {
|
||||
auto r = name.range();
|
||||
const auto& r = name.range();
|
||||
return Def::create(r,
|
||||
name,
|
||||
std::move(decl),
|
||||
decl,
|
||||
wrap_list(r, std::move(body)));
|
||||
}));
|
||||
py::class_<Decl, TreeView>(m, "Decl")
|
||||
@ -127,7 +127,7 @@ void initTreeViewBindings(PyObject *module) {
|
||||
}));
|
||||
py::class_<AugAssign, Stmt>(m, "AugAssign")
|
||||
.def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) {
|
||||
auto r = lhs.range();
|
||||
const auto& r = lhs.range();
|
||||
auto kind = AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
|
||||
return AugAssign::create(r, lhs, kind, rhs);
|
||||
}));
|
||||
@ -198,13 +198,13 @@ void initTreeViewBindings(PyObject *module) {
|
||||
}));
|
||||
py::class_<Apply, Expr>(m, "Apply")
|
||||
.def(py::init([](const Expr& expr, std::vector<Expr> args, std::vector<Attribute> kwargs) {
|
||||
auto r = expr.range();
|
||||
const auto& r = expr.range();
|
||||
return Apply::create(expr.range(), expr,
|
||||
wrap_list(r, std::move(args)), wrap_list(r, std::move(kwargs)));
|
||||
}));
|
||||
py::class_<Select, Expr>(m, "Select")
|
||||
.def(py::init([](const Expr& expr, const Ident& field) {
|
||||
auto r = expr.range();
|
||||
const auto& r = expr.range();
|
||||
return Select::create(expr.range(), expr, field);
|
||||
}));
|
||||
py::class_<TernaryIf, Expr>(m, "TernaryIf")
|
||||
|
@ -111,7 +111,7 @@ struct String : public Tree {
|
||||
};
|
||||
|
||||
static SourceRange mergeRanges(SourceRange c, const TreeList& others) {
|
||||
for (auto t : others) {
|
||||
for (const auto& t : others) {
|
||||
if (t->isAtom())
|
||||
continue;
|
||||
size_t s = std::min(c.start(), t->range().start());
|
||||
@ -171,7 +171,7 @@ struct pretty_tree {
|
||||
break;
|
||||
default:
|
||||
out << "(" << kindToString(t->kind());
|
||||
for (auto e : t->trees()) {
|
||||
for (const auto& e : t->trees()) {
|
||||
out << " " << get_flat(e);
|
||||
}
|
||||
out << ")";
|
||||
@ -188,7 +188,7 @@ struct pretty_tree {
|
||||
}
|
||||
std::string k = kindToString(t->kind());
|
||||
out << "(" << k;
|
||||
for (auto e : t->trees()) {
|
||||
for (const auto& e : t->trees()) {
|
||||
out << "\n" << std::string(indent + 2, ' ');
|
||||
print(out, e, indent + 2);
|
||||
}
|
||||
|
@ -290,7 +290,7 @@ static void py_bind_tensor_types(const std::vector<PyTensorType>& tensor_types)
|
||||
|
||||
for (auto& tensor_type : tensor_types) {
|
||||
auto name = std::string(tensor_type.name);
|
||||
auto idx = name.rfind(".");
|
||||
auto idx = name.rfind('.');
|
||||
auto type_name = name.substr(idx + 1);
|
||||
auto module_name = name.substr(0, idx);
|
||||
|
||||
|
@ -286,7 +286,7 @@ FunctionSignature::FunctionSignature(const std::string& fmt)
|
||||
while (!done) {
|
||||
auto offset = fmt.find(", ", last_offset);
|
||||
if (offset == std::string::npos) {
|
||||
offset = fmt.find(")", last_offset);
|
||||
offset = fmt.find(')', last_offset);
|
||||
done = true;
|
||||
next_offset = offset + 1;
|
||||
} else {
|
||||
|
@ -209,7 +209,7 @@ Tensor internal_new_from_data(
|
||||
auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(torch::getDeviceType(type)));
|
||||
AutoNoGIL no_gil;
|
||||
maybe_initialize_cuda(device);
|
||||
return var.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_variables);
|
||||
return var.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables);
|
||||
}
|
||||
|
||||
#ifdef USE_NUMPY
|
||||
@ -219,7 +219,7 @@ Tensor internal_new_from_data(
|
||||
auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type());
|
||||
AutoNoGIL no_gil;
|
||||
maybe_initialize_cuda(device);
|
||||
return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_numpy);
|
||||
return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -232,7 +232,7 @@ Tensor internal_new_from_data(
|
||||
auto device = device_opt.has_value() ? *device_opt : at::Device(torch::getDeviceType(type));
|
||||
AutoNoGIL no_gil;
|
||||
maybe_initialize_cuda(device);
|
||||
return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/false);
|
||||
return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/false);
|
||||
}
|
||||
|
||||
Tensor new_from_data_copy(
|
||||
|
Reference in New Issue
Block a user