Enable all clang-tidy performance checks (#15198)

Summary:
This PR adds the final set of clang-tidy checks we should add for our codebase: a last set of performance-related checks. Most fixes here are around changing `auto` to `const auto&` in a few places where unnecessary copies were made, and adding `reserve()` calls before loops doing repeated `push_back()`. Also a few cases of calling `std::string::find` with a single-character string literal instead of a single char, which uses a less efficient string search algorithm meant for searching larger substrings.

![image](https://user-images.githubusercontent.com/6429851/49978940-adc1a780-ff01-11e8-99da-a4e431361f07.png)

ezyang apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15198

Differential Revision: D13468797

Pulled By: goldsborough

fbshipit-source-id: 2bed1ea1c7c162b7f3e0e1026f17125e88c4d5b2
This commit is contained in:
Peter Goldsborough
2018-12-14 13:30:35 -08:00
committed by Facebook Github Bot
parent fc2856e9aa
commit 7a61306031
37 changed files with 75 additions and 76 deletions

View File

@ -24,7 +24,8 @@ Checks: '
,-modernize-use-auto
,-modernize-use-default-member-init
,-modernize-use-using
,performance-unnecessary-value-param
,performance-*
,-performance-noexcept-move-constructor
'
WarningsAsErrors: '*'
HeaderFilterRegex: 'torch/csrc/.*'

View File

@ -50,7 +50,7 @@ struct python_error : public std::exception {
}
python_error(python_error&& other) {
type = std::move(other.type);
type = other.type;
value = other.value;
traceback = other.traceback;
other.type = nullptr;

View File

@ -150,7 +150,7 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
/// \endrst
void apply(
const NamedModuleApplyFunction& function,
std::string name_prefix = std::string());
const std::string& name_prefix = std::string());
/// Applies the `function` to the `Module` and recursively to every submodule.
/// The function must accept a `const std::string&` for the key of the module,
@ -167,7 +167,7 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
/// \endrst
void apply(
const ConstNamedModuleApplyFunction& function,
std::string name_prefix = std::string()) const;
const std::string& name_prefix = std::string()) const;
/// Applies the `function` to the `Module` and recursively to every submodule.
/// The function must accept a `const std::shared_ptr<Module>&`.
@ -198,7 +198,7 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
/// \endrst
void apply(
const NamedModulePointerApplyFunction& function,
std::string name_prefix = std::string()) const;
const std::string& name_prefix = std::string()) const;
/// Returns the parameters of this `Module` and if `recurse` is true, also
/// recursively of every submodule.
@ -243,7 +243,7 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
/// stored in a `shared_ptr`.
/// \endrst
OrderedDict<std::string, std::shared_ptr<Module>> named_modules(
std::string name_prefix = std::string(),
const std::string& name_prefix = std::string(),
bool include_self = true) const;
/// Returns the direct submodules of this `Module`.

View File

@ -51,11 +51,11 @@ uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
return value;
}
std::string join_paths(std::string head, std::string tail) {
std::string join_paths(std::string head, const std::string& tail) {
if (head.back() != '/') {
head.push_back('/');
}
head += std::move(tail);
head += tail;
return head;
}

View File

@ -101,26 +101,26 @@ void Module::apply(const ConstModuleApplyFunction& function) const {
void Module::apply(
const NamedModuleApplyFunction& function,
std::string name_prefix) {
const std::string& name_prefix) {
function(/*name=*/name_prefix, *this);
apply_to_submodules(
[&function](
const std::string& name, const std::shared_ptr<Module>& module) {
function(name, *module);
},
std::move(name_prefix));
name_prefix);
}
void Module::apply(
const ConstNamedModuleApplyFunction& function,
std::string name_prefix) const {
const std::string& name_prefix) const {
function(/*name=*/name_prefix, *this);
apply_to_submodules(
[&function](
const std::string& name, const std::shared_ptr<Module>& module) {
function(name, *module);
},
std::move(name_prefix));
name_prefix);
}
void Module::apply(const ModulePointerApplyFunction& function) const {
@ -133,10 +133,10 @@ void Module::apply(const ModulePointerApplyFunction& function) const {
void Module::apply(
const NamedModulePointerApplyFunction& function,
std::string name_prefix) const {
const std::string& name_prefix) const {
function(
/*name=*/name_prefix, shared_from_this_checked());
apply_to_submodules(function, std::move(name_prefix));
apply_to_submodules(function, name_prefix);
}
std::vector<Tensor> Module::parameters(bool recurse) const {
@ -199,7 +199,7 @@ std::vector<std::shared_ptr<Module>> Module::modules(bool include_self) const {
}
OrderedDict<std::string, std::shared_ptr<Module>> Module::named_modules(
std::string name_prefix,
const std::string& name_prefix,
bool include_self) const {
OrderedDict<std::string, std::shared_ptr<Module>> result;
if (include_self) {
@ -208,14 +208,14 @@ OrderedDict<std::string, std::shared_ptr<Module>> Module::named_modules(
const std::string& key, const std::shared_ptr<Module>& module) {
result.insert(key, module);
},
std::move(name_prefix));
name_prefix);
} else {
apply_to_submodules(
[&result](
const std::string& key, const std::shared_ptr<Module>& module) {
result.insert(key, module);
},
std::move(name_prefix));
name_prefix);
}
return result;
}
@ -329,7 +329,7 @@ void Module::apply_to_submodules(
for (const auto& child : children_) {
auto qualified_name = join_name(name_prefix, child.key());
function(qualified_name, child.value());
child.value()->apply_to_submodules(function, std::move(qualified_name));
child.value()->apply_to_submodules(function, qualified_name);
}
}

View File

@ -13,8 +13,7 @@ namespace torch {
namespace nn {
BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {}
BatchNormImpl::BatchNormImpl(BatchNormOptions options)
: options(std::move(options)) {
BatchNormImpl::BatchNormImpl(BatchNormOptions options) : options(options) {
reset();
}

View File

@ -14,7 +14,7 @@ EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension)
: count_(count), dimension_(dimension) {}
EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options)
: options(std::move(options)) {
: options(options) {
reset();
}

View File

@ -10,7 +10,7 @@ namespace torch {
namespace nn {
LinearOptions::LinearOptions(int64_t in, int64_t out) : in_(in), out_(out) {}
LinearImpl::LinearImpl(LinearOptions options) : options(std::move(options)) {
LinearImpl::LinearImpl(LinearOptions options) : options(options) {
reset();
}

View File

@ -131,7 +131,7 @@ RNNOutput RNNImplBase<Derived>::generic_forward(
}
Tensor output, new_state;
std::tie(output, new_state) = function(
std::move(input),
input,
std::move(state),
flat_weights_,
options.with_bias_,
@ -208,12 +208,12 @@ RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
case RNNActivation::ReLU:
return generic_forward(
static_cast<RNNFunctionSignature*>(&torch::rnn_relu),
std::move(input),
input,
std::move(state));
case RNNActivation::Tanh:
return generic_forward(
static_cast<RNNFunctionSignature*>(&torch::rnn_tanh),
std::move(input),
input,
std::move(state));
default:
AT_ERROR("Unhandled RNN activation function!");
@ -244,7 +244,7 @@ RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) {
}
Tensor output, hidden_state, cell_state;
std::tie(output, hidden_state, cell_state) = torch::lstm(
std::move(input),
input,
{state[0], state[1]},
flat_weights_,
options.with_bias_,
@ -266,9 +266,7 @@ GRUImpl::GRUImpl(const GRUOptions& options)
RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) {
return generic_forward(
static_cast<RNNFunctionSignature*>(&torch::gru),
std::move(input),
std::move(state));
static_cast<RNNFunctionSignature*>(&torch::gru), input, std::move(state));
}
} // namespace nn
} // namespace torch

View File

@ -17,6 +17,7 @@ void serialize(
const std::string& key,
const std::vector<int64_t>& steps) {
std::vector<torch::Tensor> tensors;
tensors.reserve(steps.size());
for (const auto& step : steps) {
tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
}

View File

@ -153,7 +153,7 @@ inline Tensor as_variable(Tensor tensor) {
inline std::vector<Tensor> as_variable(TensorList tl) {
return fmap(tl, [](const Tensor& t) -> Tensor {
return make_variable(std::move(t), /*requires_grad=*/false);
return make_variable(t, /*requires_grad=*/false);
});
}

View File

@ -408,7 +408,7 @@ static variable_list call_function(FunctionTask& task) {
if(has_post_hooks){
// NOLINTNEXTLINE(bugprone-use-after-move)
return call_post_hooks(fn, std::move(outputs), std::move(inputs));
return call_post_hooks(fn, std::move(outputs), inputs);
}
return outputs;
}

View File

@ -585,10 +585,7 @@ static Node* _trace_pre_record(
Py_INCREF(op_obj);
auto pyobj = THPObjectPtr(op_obj);
return jit::tracer::preRecordPythonTrace(
std::move(pyobj),
std::move(arg_types),
input_vars,
std::move(scalar_args));
std::move(pyobj), arg_types, input_vars, std::move(scalar_args));
}
static void _trace_post_record(

View File

@ -31,7 +31,7 @@ BatchTensor::BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dim
sizes[0] = bs;
mask_sizes[0] = bs;
for(int64_t i = 1; i < dims.size(0) + 1; i++){
for(auto x : datalist){
for(const auto& x : datalist){
sizes[i] = std::max(sizes[i], x.size(i));
}
mask_sizes[i] = *dims[i - 1].data<uint8_t>() ? sizes[i] : 1;

View File

@ -102,19 +102,19 @@ RegisterOperators reg({
return 0;
};
} else if(type->isSubtypeOf(ListType::ofInts())) {
auto is = node->is(attr::value);
const auto& is = node->is(attr::value);
return [is](Stack& stack) {
push(stack, is);
return 0;
};
} else if(type->isSubtypeOf(ListType::ofBools())) {
auto bs = node->is(attr::value);
const auto& bs = node->is(attr::value);
return [bs](Stack& stack) {
push(stack, bs);
return 0;
};
} else if(type->isSubtypeOf(ListType::ofTensors())) {
auto ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor {
const auto& ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor {
return autograd::make_variable(t);
});
return [ts](Stack& stack) {
@ -122,7 +122,7 @@ RegisterOperators reg({
return 0;
};
} else if (type == StringType::get()) {
auto s = node->s(attr::value);
const auto& s = node->s(attr::value);
return [s](Stack& stack) {
push(stack, s);
return 0;

View File

@ -63,7 +63,7 @@ Node* getTracedNode(
const std::tuple<Types...>& tuple) {
auto symbol = Symbol::fromQualString(schema.name());
const auto& graph = tracer::getTracingState()->graph;
Node* node = graph->create(std::move(symbol), /*num_outputs=*/0);
Node* node = graph->create(symbol, /*num_outputs=*/0);
tracer::recordSourceLocation(node);
// Hack to call addInputs for the parameter pack in a sequenced fashion.

View File

@ -51,7 +51,7 @@ int debugFuser() {
// If the given node is used once by a chunk node, returns that node.
// Returns nullptr otherwise.
static const Node* usedInFusedChunk(const Value* input) {
const auto uses = input->uses();
const auto& uses = input->uses();
if (uses.size() == 1) {
const Node *user = uses[0].user;
if (user->kind() == prim::ConstantChunk) {

View File

@ -503,7 +503,7 @@ private:
}
void runTraced(Stack & stack) {
auto state = tracer::getTracingState();
const auto& state = tracer::getTracingState();
auto inputs = last(stack, num_inputs);
auto input_values = fmap(inputs, [](const IValue & v) {
return tracer::getNestedValueTrace(v);

View File

@ -300,7 +300,7 @@ void initJITBindings(PyObject *module) {
m.def("_jit_get_operation", [](const std::string& qualified_name) {
try {
auto symbol = Symbol::fromQualString(qualified_name);
auto operations = getAllOperatorsFor(std::move(symbol));
auto operations = getAllOperatorsFor(symbol);
AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
AT_CHECK(
operations.size() == 1,
@ -338,7 +338,7 @@ void initJITBindings(PyObject *module) {
});
m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
auto symbol = Symbol::fromQualString(qualified_name);
auto operations = getAllOperatorsFor(std::move(symbol));
auto operations = getAllOperatorsFor(symbol);
return fmap(operations, [](const std::shared_ptr<Operator>& op) {
return op->schema();
});

View File

@ -73,7 +73,7 @@ struct Suspend : public std::exception {
struct InterpreterContinuation {
InterpreterContinuation(InterpreterState state_, Stack stack_)
: state(std::move(state_)), stack(std::move(stack_)) {}
: state(state_), stack(std::move(stack_)) {}
void operator()() {
state.runAsync(stack);

View File

@ -253,7 +253,7 @@ struct SchemaParser {
n = "-" + L.expect(TK_NUMBER).text();
else
n = L.expect(TK_NUMBER).text();
if(kind == TypeKind::FloatType || n.find(".") != std::string::npos || n.find("e") != std::string::npos) {
if(kind == TypeKind::FloatType || n.find('.') != std::string::npos || n.find('e') != std::string::npos) {
return std::stod(n);
} else {
int64_t v = std::stoll(n);
@ -405,7 +405,7 @@ private:
// XXX - caller must be holding lock
void registerPendingOperators() {
for(auto op : to_register) {
for(const auto& op : to_register) {
Symbol sym = Symbol::fromQualString(op->schema().name());
operators[sym].push_back(op);
operators_by_sig[canonicalSchemaString(op->schema())] = op;

View File

@ -531,7 +531,7 @@ void AliasDb::addAlias(const Value* value, Symbol alias) {
valueToAlias_[value].addSet(alias);
} else {
AliasInfo aliasInfo;
aliasInfo.addSet(std::move(alias));
aliasInfo.addSet(alias);
valueToAlias_.insert({value, std::move(aliasInfo)});
}
}

View File

@ -973,6 +973,7 @@ void PeepholeOptimizeShapeExpressions(Block * block) {
}
if (unique_to_value.size() != node->inputs().size()) {
std::vector<Value*> inputs;
inputs.reserve(unique_to_value.size());
for (auto & entry : unique_to_value) {
inputs.push_back(entry.second);
}

View File

@ -637,7 +637,7 @@ struct PythonPrintPass {
} else if(v.isTensorList()) {
stmt << "[";
const char* delim = "";
for(auto t : v.toTensorListRef()) {
for(const auto& t : v.toTensorListRef()) {
stmt << delim << "CONSTANTS.c" << getOrAddTensorConstant(t);
delim = ", ";
}

View File

@ -15,7 +15,7 @@ void specializeUndef(Graph & g) {
std::unordered_map<Value*, State> state;
for (Value* input : g.inputs()) {
auto tp = input->type();
const auto& tp = input->type();
if (tp->isSubtypeOf(UndefinedTensorType::get())) {
state[input] = State::Undefined;
} else if (tp->isSubtypeOf(DynamicType::get())) {

View File

@ -467,7 +467,7 @@ void initPythonIRBindings(PyObject * module_) {
.def(py::init([](std::vector<TypePtr> a){ return TupleType::create(a); }))
.def("elements", [](TupleType &self){
std::vector<TypePtr> types;
for (auto type : self.elements()) {
for (const auto& type : self.elements()) {
types.push_back(type);
}
return types;

View File

@ -165,7 +165,7 @@ void initPythonTracerBindings(PyObject* module) {
return setValueTrace(var, value);
});
m.def("_tracer_set_get_unique_name_fn", [](py::function func) {
auto tracing_state = getTracingState();
const auto& tracing_state = getTracingState();
JIT_ASSERT(tracing_state);
tracing_state->lookup_var_name_fn = [func](const Variable& var) -> std::string {
AutoGIL ag;
@ -173,7 +173,7 @@ void initPythonTracerBindings(PyObject* module) {
};
});
m.def("_tracer_set_force_outplace", [](bool force_outplace) {
auto tracing_state = getTracingState();
const auto& tracing_state = getTracingState();
JIT_ASSERT(tracing_state);
tracing_state->force_outplace = force_outplace;
});

View File

@ -401,7 +401,7 @@ RegisterOperators reg({
return [=](Stack& stack) {
bool result = false;
for (const IValue& t : last(stack, num_inputs)) {
if (std::move(t).toTensor().defined()) {
if (t.toTensor().defined()) {
result = true;
break;
}
@ -1135,6 +1135,7 @@ Operator( \
at::Tensor t;
pop(stack, t);
std::vector<int64_t> elems;
elems.reserve(t.size(0));
for(int i = 0; i < t.size(0); i++){
elems.push_back(*t[i].data<int32_t>());
}

View File

@ -419,7 +419,7 @@ static inline bool isIntOrFloatUsedAsList(
const Value* value,
const Argument& arg) {
// Look for int[N] or float[N]
auto v_type = value->type();
const auto& v_type = value->type();
if (v_type != FloatType::get() && v_type != IntType::get())
return false;
auto arg_type = unwrapOptional(arg.type());
@ -1054,7 +1054,7 @@ private:
<< " return (" << schema.returns().size() << ") does not match"
<< " the number of returns from the function (" << results.size() << ")!";
}
auto range = return_stmt.range();
const auto& range = return_stmt.range();
size_t return_type_idx = 0;
for (auto r : results) {
TypePtr type = DynamicType::get();
@ -1506,7 +1506,7 @@ private:
auto instances = sv->asTuple(stmt.range(), method);
const std::string& target_name = target.name();
pushFrame(environment_stack->block());
for(auto inst : instances) {
for(const auto& inst : instances) {
environment_stack->setSugaredVar(itrs[0].range(), target_name, inst);
emitStatements(body);
}
@ -1988,7 +1988,7 @@ private:
if(maybe_unpack && tree->kind() == TK_STARRED) {
auto starred = Starred(tree);
auto entries = emitSugaredExpr(starred.expr(), 1)->asTuple(starred.range(), method);
for(auto entry : entries) {
for(const auto& entry : entries) {
values.emplace_back(
tree->range(), entry->asValue(starred.range(), method));
}
@ -2639,7 +2639,7 @@ void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::vector<D
auto resolver_it = resolvers.begin();
std::vector<Method*> methods;
std::unordered_map<std::string, Method*> function_table;
for(Def def : definitions) {
for(const Def& def : definitions) {
const std::string& name = def.name().name();
auto resolver = *resolver_it++;
JIT_ASSERT(resolver);

View File

@ -121,7 +121,7 @@ private:
struct TORCH_API BuiltinFunction : public SugaredValue {
BuiltinFunction(Symbol symbol, c10::optional<NamedValue> self)
: symbol(std::move(symbol)), self(std::move(self)) {}
: symbol(symbol), self(std::move(self)) {}
// The symbol of the function (e.g. `aten::relu`).
Symbol symbol;

View File

@ -185,7 +185,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
const SourceRange& loc,
Method& m) override {
std::vector<Value*> values;
for (auto sugared_item : asTuple(loc, m)) {
for (const auto& sugared_item : asTuple(loc, m)) {
values.push_back(sugared_item->asValue(loc, m));
}
auto node = m.graph()->createTuple(values);
@ -532,6 +532,7 @@ void initJitScriptBindings(PyObject* module) {
const std::vector<ResolutionCallback>& rcbs,
const std::vector<FunctionDefaults>& defaults) {
std::vector<Resolver> resolvers;
resolvers.reserve(rcbs.size());
for(auto & callback : rcbs) {
resolvers.push_back(pythonResolver(callback));
}

View File

@ -123,7 +123,7 @@ struct Method {
stack.push_back(*inp);
}
const auto size = stack.size();
setInputTypes(*retval, ArgumentSpec(with_grad, std::move(stack), size));
setInputTypes(*retval, ArgumentSpec(with_grad, stack, size));
PropagateInputShapes(retval);
return retval;
}

View File

@ -107,10 +107,10 @@ void initTreeViewBindings(PyObject *module) {
.def(py::init([](const Ident& name,
Decl decl,
std::vector<Stmt> body) {
auto r = name.range();
const auto& r = name.range();
return Def::create(r,
name,
std::move(decl),
decl,
wrap_list(r, std::move(body)));
}));
py::class_<Decl, TreeView>(m, "Decl")
@ -127,7 +127,7 @@ void initTreeViewBindings(PyObject *module) {
}));
py::class_<AugAssign, Stmt>(m, "AugAssign")
.def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) {
auto r = lhs.range();
const auto& r = lhs.range();
auto kind = AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
return AugAssign::create(r, lhs, kind, rhs);
}));
@ -198,13 +198,13 @@ void initTreeViewBindings(PyObject *module) {
}));
py::class_<Apply, Expr>(m, "Apply")
.def(py::init([](const Expr& expr, std::vector<Expr> args, std::vector<Attribute> kwargs) {
auto r = expr.range();
const auto& r = expr.range();
return Apply::create(expr.range(), expr,
wrap_list(r, std::move(args)), wrap_list(r, std::move(kwargs)));
}));
py::class_<Select, Expr>(m, "Select")
.def(py::init([](const Expr& expr, const Ident& field) {
auto r = expr.range();
const auto& r = expr.range();
return Select::create(expr.range(), expr, field);
}));
py::class_<TernaryIf, Expr>(m, "TernaryIf")

View File

@ -111,7 +111,7 @@ struct String : public Tree {
};
static SourceRange mergeRanges(SourceRange c, const TreeList& others) {
for (auto t : others) {
for (const auto& t : others) {
if (t->isAtom())
continue;
size_t s = std::min(c.start(), t->range().start());
@ -171,7 +171,7 @@ struct pretty_tree {
break;
default:
out << "(" << kindToString(t->kind());
for (auto e : t->trees()) {
for (const auto& e : t->trees()) {
out << " " << get_flat(e);
}
out << ")";
@ -188,7 +188,7 @@ struct pretty_tree {
}
std::string k = kindToString(t->kind());
out << "(" << k;
for (auto e : t->trees()) {
for (const auto& e : t->trees()) {
out << "\n" << std::string(indent + 2, ' ');
print(out, e, indent + 2);
}

View File

@ -290,7 +290,7 @@ static void py_bind_tensor_types(const std::vector<PyTensorType>& tensor_types)
for (auto& tensor_type : tensor_types) {
auto name = std::string(tensor_type.name);
auto idx = name.rfind(".");
auto idx = name.rfind('.');
auto type_name = name.substr(idx + 1);
auto module_name = name.substr(0, idx);

View File

@ -286,7 +286,7 @@ FunctionSignature::FunctionSignature(const std::string& fmt)
while (!done) {
auto offset = fmt.find(", ", last_offset);
if (offset == std::string::npos) {
offset = fmt.find(")", last_offset);
offset = fmt.find(')', last_offset);
done = true;
next_offset = offset + 1;
} else {

View File

@ -209,7 +209,7 @@ Tensor internal_new_from_data(
auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(torch::getDeviceType(type)));
AutoNoGIL no_gil;
maybe_initialize_cuda(device);
return var.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_variables);
return var.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables);
}
#ifdef USE_NUMPY
@ -219,7 +219,7 @@ Tensor internal_new_from_data(
auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type());
AutoNoGIL no_gil;
maybe_initialize_cuda(device);
return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_numpy);
return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy);
}
#endif
@ -232,7 +232,7 @@ Tensor internal_new_from_data(
auto device = device_opt.has_value() ? *device_opt : at::Device(torch::getDeviceType(type));
AutoNoGIL no_gil;
maybe_initialize_cuda(device);
return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/false);
return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/false);
}
Tensor new_from_data_copy(