Some performance fixes (#94034)

Applies assorted small performance fixes across the codebase: prefers emplace_back/emplace over push_back/insert, passes read-only arguments by const reference instead of by value, moves values on their last use instead of copying them, defaults trivial constructors and destructors, reserves vector capacity before fill loops, and value-initializes raw pointers so the NOLINT suppressions for uninitialized variables can be dropped.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94034
Approved by: https://github.com/Skylion007
Author: cyy
Date: 2023-02-04 02:17:45 +00:00
Committed by: PyTorch MergeBot
parent fa65ae8f56
commit 1a32db15e7
26 changed files with 62 additions and 65 deletions

View File

@@ -201,7 +201,7 @@ inline std::vector<Tensor> cached_cast(
 std::vector<Tensor> vec;
 vec.reserve(arg.size());
 for (const auto& t : arg) {
-vec.push_back(cached_cast(to_type, t, device_type));
+vec.emplace_back(cached_cast(to_type, t, device_type));
 }
 return vec;
 }
@@ -213,7 +213,7 @@ inline std::vector<Tensor> cached_cast(
 std::vector<Tensor> vec;
 vec.reserve(arg.size());
 for (const auto& t : arg) {
-vec.push_back(cached_cast(to_type, t, device_type));
+vec.emplace_back(cached_cast(to_type, t, device_type));
 }
 return vec;
 }

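Editor's note: the dominant change in this commit is swapping push_back for emplace_back. For an already-constructed temporary like the cached_cast result above the two cost the same (one move); emplace_back mainly wins when raw constructor arguments are forwarded so no temporary is built at all. A minimal standalone sketch of that distinction (Widget is illustrative, not PyTorch code):

#include <string>
#include <utility>
#include <vector>

struct Widget {
  std::string name;
  int id;
  Widget(std::string n, int i) : name(std::move(n)), id(i) {}
};

int main() {
  std::vector<Widget> v;
  v.reserve(2);
  // push_back requires a Widget: a temporary is constructed, then moved in.
  v.push_back(Widget("a", 1));
  // emplace_back forwards the arguments and constructs the Widget in place.
  v.emplace_back("b", 2);
}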
View File

@@ -118,7 +118,7 @@ public:
 return data_;
 }
-inline void set_data(mt19937_data_pod data) {
+inline void set_data(const mt19937_data_pod& data) {
 data_ = data;
 }

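The signature change above avoids copying the whole Mersenne Twister state block at every call; the function only needs to read its argument, so a const reference is enough and the single necessary copy happens in the assignment to the member. A sketch of the same pattern with an illustrative stand-in type:

#include <array>
#include <cstdint>

// Illustrative stand-in for a large, trivially copyable state blob (~5 KB).
struct EngineState {
  std::array<std::uint64_t, 624> data{};
};

class Engine {
 public:
  // const& avoids one full copy at the call boundary; the only copy
  // that is actually required is the assignment into the member.
  void set_data(const EngineState& s) { state_ = s; }

 private:
  EngineState state_;
};

int main() {
  Engine e;
  EngineState s;
  e.set_data(s);
}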
View File

@@ -524,9 +524,9 @@ void ClassType::checkNotExist(const std::string& name, const std::string& what)
 }
 void ClassType::addAttribute(ClassAttribute classAttribute) {
-attributes_.push_back(classAttribute);
-attributeTypes_.push_back(classAttribute.getType());
 AT_ASSERT(attributes_.size() == attributeTypes_.size());
+attributeTypes_.emplace_back(classAttribute.getType());
+attributes_.emplace_back(std::move(classAttribute));
 }
 size_t ClassType::addAttribute(

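Note on the reordering above: the attribute's type has to be read before the attribute itself is moved from, so the type is recorded first and the ClassAttribute is then moved into its container, saving one copy of the by-value parameter. The same last-use-move pattern in a small, self-contained form (types here are illustrative):

#include <string>
#include <utility>
#include <vector>

struct Attribute {
  std::string name;
  std::string type;
  const std::string& getType() const { return type; }
};

struct Registry {
  void add(Attribute attr) {
    // Read what we need from attr first...
    types.emplace_back(attr.getType());
    // ...then move the by-value parameter into storage instead of copying it.
    attrs.emplace_back(std::move(attr));
  }
  std::vector<Attribute> attrs;
  std::vector<std::string> types;
};

int main() {
  Registry r;
  r.add(Attribute{"weight", "Tensor"});
}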
View File

@@ -879,7 +879,7 @@ IValue IValue::deepcopy(
 case IValue::Tag::Tuple: {
 std::vector<IValue> copied_tuple;
 for (const auto& e : toTupleRef().elements()) {
-copied_tuple.push_back(e.deepcopy(memo));
+copied_tuple.emplace_back(e.deepcopy(memo));
 }
 copy = IValue(ivalue::Tuple::create(std::move(copied_tuple)));
 }
@@ -1067,11 +1067,11 @@ std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> ivalue::Future::extractSt
 if (tensor.is_sparse()) {
 // Sparse tensor is indices and values. Both are tensors
 // and contain storage.
-weakStorageImpls.push_back(tensor.indices().storage().getWeakStorageImpl());
-weakStorageImpls.push_back(tensor.values().storage().getWeakStorageImpl());
+weakStorageImpls.emplace_back(tensor.indices().storage().getWeakStorageImpl());
+weakStorageImpls.emplace_back(tensor.values().storage().getWeakStorageImpl());
 } else {
 // A dense/strided tensor contains 1 storage
-weakStorageImpls.push_back(tensor.storage().getWeakStorageImpl());
+weakStorageImpls.emplace_back(tensor.storage().getWeakStorageImpl());
 }
 }
 } else {
@@ -1081,7 +1081,7 @@ std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> ivalue::Future::extractSt
 value.getSubValues(sub_values);
 for (const at::IValue& sub_value : sub_values) {
 if (sub_value.isTensor()) {
-weakStorageImpls.push_back(sub_value.toTensor().storage().getWeakStorageImpl());
+weakStorageImpls.emplace_back(sub_value.toTensor().storage().getWeakStorageImpl());
 }
 }
 }

View File

@@ -162,7 +162,7 @@ void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten) {
 "passed a `nullptr`");
 std::vector<TypePtr> to_fill;
 standardizeVectorForUnion(*to_flatten, &to_fill);
-*to_flatten = to_fill;
+*to_flatten = std::move(to_fill);
 }
 OptionalType::OptionalType(TypePtr contained)

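The std::move above turns what used to be an element-by-element vector copy into a pointer swap; to_fill is a local that dies at the end of the function, so moving from it is safe. A tiny illustration of the same idea:

#include <utility>
#include <vector>

// Rewrites *out with a filtered copy of its own contents (illustrative).
void keep_positive(std::vector<int>* out) {
  std::vector<int> tmp;
  tmp.reserve(out->size());
  for (int v : *out) {
    if (v > 0) tmp.push_back(v);
  }
  // Move assignment steals tmp's buffer; a plain copy would reallocate
  // and copy every element for no benefit.
  *out = std::move(tmp);
}

int main() {
  std::vector<int> v{-1, 2, -3, 4};
  keep_positive(&v);  // v == {2, 4}
}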
View File

@@ -65,7 +65,7 @@ struct DivMod {
 // everything else, we use plain division.
 template <typename Value>
 struct IntDivider {
-IntDivider() { } // Dummy constructor for arrays.
+IntDivider() = default;
 IntDivider(Value d) : divisor(d) { }
 C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; }
@@ -82,7 +82,7 @@ template <>
 struct IntDivider<unsigned int> {
 static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int.");
-IntDivider() { } // Dummy constructor for arrays.
+IntDivider() = default;
 IntDivider(unsigned int d) : divisor(d) {
 assert(divisor >= 1 && divisor <= INT32_MAX);

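Writing IntDivider() = default instead of an empty-brace body still gives arrays the default constructor they need (the point of the original "dummy constructor for arrays" comment), and additionally keeps the type trivially default-constructible, so no per-element constructor calls are emitted. A minimal check of that property with an illustrative type:

#include <type_traits>

struct Divider {
  Divider() = default;            // defaulted, hence still trivial
  Divider(unsigned d) : divisor(d) {}
  unsigned divisor;               // left uninitialized on purpose, as in the original
};

// A user-provided `Divider() {}` body would make this assertion fail.
static_assert(std::is_trivially_default_constructible<Divider>::value,
              "arrays of Divider need no per-element construction code");

int main() {
  Divider table[64];              // no constructor code runs per element
  (void)table;
}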
View File

@@ -46,7 +46,7 @@ void NnapiCompilation::init(
 void NnapiCompilation::init2(
 at::Tensor serialized_model_tensor,
-std::vector<at::Tensor> parameter_buffers,
+const std::vector<at::Tensor>& parameter_buffers,
 int64_t compilation_preference,
 bool relax_f32_to_f16
 ) {
@@ -55,7 +55,9 @@ void NnapiCompilation::init2(
 load_platform_library();
 std::vector<const void*> buffers;
+buffers.reserve(parameter_buffers.size());
 std::vector<int32_t> buffer_sizes;
+buffer_sizes.reserve(parameter_buffers.size());
 for (auto& t : parameter_buffers) {
 TORCH_CHECK(t.is_contiguous());
 buffers.push_back(t.data_ptr());
@@ -75,8 +77,7 @@ void NnapiCompilation::init2(
 };
 TORCH_CHECK(!ser_model.empty());
-// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-ANeuralNetworksModel* model;
+ANeuralNetworksModel* model{};
 check_nnapi->Model_create(&model);
 CAFFE_ENFORCE(model);
 model_.reset(model);
@@ -102,8 +103,7 @@ void NnapiCompilation::init2(
 }
 check_nnapi->Model_finish(model_.get());
-// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-ANeuralNetworksCompilation* compilation;
+ANeuralNetworksCompilation* compilation{};
 check_nnapi->Compilation_create(model_.get(), &compilation);
 // TODO: Make this configurable.
 check_nnapi->Compilation_setPreference(compilation, static_cast<int32_t>(compilation_preference));

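Two small patterns in the hunks above: value-initializing the raw handle pointers with {} replaces the NOLINT-suppressed uninitialized declarations, and reserve() before a push_back loop avoids repeated reallocation. A generic sketch (the C-style create function is illustrative, not the NNAPI API):

#include <cstddef>
#include <vector>

struct Handle {};                       // stand-in for an opaque C API handle type
void handle_create(Handle** out) { *out = new Handle{}; }

int main() {
  Handle* h{};                          // value-initialized to nullptr; no lint suppression needed
  handle_create(&h);

  const std::size_t n = 1024;
  std::vector<const void*> ptrs;
  ptrs.reserve(n);                      // one allocation up front instead of repeated regrowth
  for (std::size_t i = 0; i < n; ++i) {
    ptrs.push_back(nullptr);
  }

  delete h;
}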
View File

@@ -44,7 +44,7 @@ struct NnapiCompilation : torch::jit::CustomClassHolder {
 TORCH_API void init2(
 at::Tensor serialized_model_tensor,
-std::vector<at::Tensor> parameter_buffers,
+const std::vector<at::Tensor>& parameter_buffers,
 int64_t compilation_preference,
 bool relax_f32_to_f16
 );

View File

@@ -208,7 +208,7 @@ class InlineMultiStreamGuard {
 impl_.emplace(getDeviceTypeOfStreams(streams));
 original_streams_.reserve(streams.size());
 for (const Stream& s : streams) {
-original_streams_.push_back(this->impl_->exchangeStream(s));
+original_streams_.emplace_back(this->impl_->exchangeStream(s));
 }
 }
 }

View File

@@ -262,9 +262,8 @@ inline void free_impl(PtrInfo::iterator& it) {
 if (C10_UNLIKELY(capture_underway)) {
 // See Note [Avoid dangling free streams during CUDA graph capture]
-capture_free_streams.insert(UsageStream(
-dummy_unifying_free_stream.stream,
-dummy_unifying_free_stream.device));
+capture_free_streams.emplace(
+dummy_unifying_free_stream.stream, dummy_unifying_free_stream.device);
 }
 }

View File

@@ -138,10 +138,10 @@ struct KeyOrValueEquality : functor_storage<bool, key_equal> {
 static constexpr int8_t min_lookups = 4;
 template <typename T>
 struct sherwood_v3_entry {
-sherwood_v3_entry() {}
+sherwood_v3_entry() = default;
 sherwood_v3_entry(int8_t distance_from_desired)
 : distance_from_desired(distance_from_desired) {}
-~sherwood_v3_entry() {}
+~sherwood_v3_entry() = default;
 bool has_value() const {
 return distance_from_desired >= 0;

View File

@@ -114,7 +114,7 @@ private:
 };
 struct object : public handle {
-object() {}
+object() = default;
 object(const object& other)
 : handle(other.ptr_) {
 Py_XINCREF(ptr_);
@@ -160,7 +160,7 @@ protected:
 template<typename T>
 struct obj : public object {
-obj() {}
+obj() = default;
 obj(const obj& other)
 : object(other.ptr_) {
 Py_XINCREF(ptr_);

View File

@@ -179,7 +179,7 @@ static c10::intrusive_ptr<c10::StorageImpl> THPStorage_newFdStorage(
 at::ALLOCATOR_MAPPED_KEEPFD | at::ALLOCATOR_MAPPED_UNLINK;
 std::string handle = at::NewProcessWideShmHandle();
 auto sptr = at::MapAllocator::makeDataPtr(
-handle.c_str(), flags, size * sizeof(uint8_t), nullptr);
+handle, flags, size * sizeof(uint8_t), nullptr);
 return c10::make_intrusive<at::StorageImpl>(
 c10::StorageImpl::use_byte_size_t(),
 size,

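On the c_str() removal above: passing the std::string itself (assuming makeDataPtr accepts one, which this page does not show) spares the callee a strlen plus a fresh std::string construction from the raw pointer. The generic shape of that trade-off, with illustrative overloads:

#include <string>

// Illustrative callee with both overloads.
void open_handle(const std::string& name) { (void)name.size(); }        // length already known
void open_handle(const char* name) { open_handle(std::string(name)); }  // strlen + copy first

int main() {
  std::string handle = "/torch_shm_region_0";
  open_handle(handle);          // binds to the std::string overload directly
  open_handle(handle.c_str());  // forces the char* path and a needless rebuild
}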
View File

@@ -19,7 +19,7 @@ namespace utils {
 // sense!) in order to return a CPU-side `double`. This C++ version therefore
 // cannot be run fully asynchronously w.r.t. the device of the gradients.
 inline double clip_grad_norm_(
-std::vector<Tensor> parameters,
+const std::vector<Tensor>& parameters,
 double max_norm,
 double norm_type = 2.0,
 bool error_if_nonfinite = false) {
@@ -118,7 +118,7 @@ inline double clip_grad_norm_(
 // See https://pytorch.org/docs/stable/nn.html#clip-grad-value
 // for more details about this module.
 inline void clip_grad_value_(
-std::vector<Tensor> parameters,
+const std::vector<Tensor>& parameters,
 double clip_value) {
 for (const auto& param : parameters) {
 if (param.grad().defined()) {

View File

@@ -56,20 +56,18 @@ inline torch::Tensor parameters_to_vector(
 // Convert one vector to the parameters
 inline void vector_to_parameters(
 const torch::Tensor& vec,
-std::vector<torch::Tensor> parameters) {
+const std::vector<torch::Tensor>& parameters) {
 // Flag for the device where the parameter is located
 c10::optional<int64_t> param_device;
 // Pointer for slicing the vector for each parameter
 int64_t pointer = 0;
-// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-int64_t num_param;
-for (torch::Tensor& param : parameters) {
+for (const torch::Tensor& param : parameters) {
 // Ensure the parameters are located in the same device
 param_device = _check_param_device(param, param_device);
 // The length of the parameter
-num_param = param.numel();
+auto num_param = param.numel();
 // Slice the vector, reshape it, and replace the old data of the parameter
 param.set_data(
 vec.slice(0, pointer, pointer + num_param).view_as(param).data());

View File

@@ -2779,7 +2779,7 @@ static inline c10::SymInt _min_storage_size(
 // explanation
 Tensor as_strided_backward(
 Tensor grad,
-TensorGeometry input_geometry,
+const TensorGeometry& input_geometry,
 c10::SymIntArrayRef sym_sizes,
 c10::SymIntArrayRef sym_strides,
 optional<c10::SymInt> sym_storage_offset_) {
@@ -2908,7 +2908,7 @@ Tensor as_strided_backward(
 Tensor as_strided_scatter_backward(
 Tensor grad,
-TensorGeometry input_geometry,
+const TensorGeometry& input_geometry,
 TensorGeometry src_geometry,
 c10::SymIntArrayRef sizes,
 c10::SymIntArrayRef strides,

View File

@@ -715,13 +715,13 @@ Tensor gelu_double_backward(
 c10::string_view approximate);
 Tensor as_strided_backward(
 Tensor grad,
-TensorGeometry input_geometry,
+const TensorGeometry& input_geometry,
 c10::SymIntArrayRef sizes,
 c10::SymIntArrayRef strides,
 optional<c10::SymInt> storage_offset_);
 Tensor as_strided_scatter_backward(
 Tensor grad,
-TensorGeometry input_geometry,
+const TensorGeometry& input_geometry,
 TensorGeometry src_geometry,
 c10::SymIntArrayRef sizes,
 c10::SymIntArrayRef strides,

View File

@@ -276,7 +276,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
 void add_next_edge(Edge edge) {
 update_topological_nr(edge);
-next_edges_.push_back(std::move(edge));
+next_edges_.emplace_back(std::move(edge));
 }
 void set_next_edges(edge_list&& next_edges) {
@@ -456,7 +456,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
-post_hooks_.push_back(std::move(post_hook));
+post_hooks_.emplace_back(std::move(post_hook));
 // Use the raw pointer as the unique key to identify this hook. This key
 // can then be used in del_post_hook(key) to remove this hook.
 return reinterpret_cast<std::uintptr_t>(post_hooks_.back().get());
@@ -483,11 +483,11 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
 }
 void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
-pre_hooks_.push_back(std::move(pre_hook));
+pre_hooks_.emplace_back(std::move(pre_hook));
 }
 void add_tensor_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
-tensor_pre_hooks_.push_back(std::move(pre_hook));
+tensor_pre_hooks_.emplace_back(std::move(pre_hook));
 }
 void add_retains_grad_hook(
@@ -672,7 +672,7 @@ struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
 void operator()(const Variable& variable) {
 // NOLINTNEXTLINE(bugprone-branch-clone)
 if (variable.defined()) {
-next_edges.push_back(impl::gradient_edge(variable));
+next_edges.emplace_back(impl::gradient_edge(variable));
 } else {
 next_edges.emplace_back();
 }
@@ -680,7 +680,7 @@ struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
 void operator()(const Variable* variable) {
 // NOLINTNEXTLINE(bugprone-branch-clone)
 if (variable->defined()) {
-next_edges.push_back(impl::gradient_edge(*variable));
+next_edges.emplace_back(impl::gradient_edge(*variable));
 } else {
 next_edges.emplace_back();
 }
@@ -688,7 +688,7 @@ struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
 void operator()(const c10::optional<Variable>& variable) {
 // NOLINTNEXTLINE(bugprone-branch-clone)
 if (variable.has_value() && variable->defined()) {
-next_edges.push_back(impl::gradient_edge(*variable));
+next_edges.emplace_back(impl::gradient_edge(*variable));
 } else {
 next_edges.emplace_back();
 }

View File

@@ -54,9 +54,9 @@ static inline std::vector<Tensor>& _broadcast_out_impl(
 #ifdef USE_NCCL
 std::vector<Tensor> nccl_list;
 nccl_list.reserve(out_tensors.size() + 1);
-nccl_list.push_back(tensor);
+nccl_list.emplace_back(tensor);
 for (auto& out_tensor : out_tensors) {
-nccl_list.push_back(out_tensor);
+nccl_list.emplace_back(out_tensor);
 }
 if (nccl::is_available(nccl_list)) {
 nccl::broadcast(nccl_list);
@@ -102,7 +102,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
 TORCH_CHECK(
 device >= 0, "Expected non-negative device index, but got ", device);
 if (device != tensor.get_device()) {
-diff_device_dst_tensors.push_back(at::empty(
+diff_device_dst_tensors.emplace_back(at::empty(
 tensor.sizes(),
 tensor.options().device(
 at::Device(DeviceType::CUDA, device)))); // preserve memory format
@@ -116,9 +116,9 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
 for (auto device : devices) {
 // NOLINTNEXTLINE(bugprone-branch-clone)
 if (device != tensor.get_device()) {
-dst_tensors.push_back(*it++);
+dst_tensors.emplace_back(*it++);
 } else {
-dst_tensors.push_back(tensor);
+dst_tensors.emplace_back(tensor);
 }
 }
 TORCH_INTERNAL_ASSERT(it == diff_device_dst_tensors.end());
@@ -197,7 +197,7 @@ tensor_list2d broadcast_coalesced(
 for (const auto& var : torch::utils::unflatten_sparse_tensors(
 inds, vals, chunk.tensors)) {
 // See NOTE [ Version Counter in comm.*_coalesced ]
-device_outputs.push_back(make_variable(var.tensor_data(), false));
+device_outputs.emplace_back(make_variable(var.tensor_data(), false));
 }
 }
 } else {
@@ -209,7 +209,7 @@ tensor_list2d broadcast_coalesced(
 for (auto& var :
 torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
 // See NOTE [ Version Counter in comm.*_coalesced ]
-device_outputs.push_back(make_variable(var.tensor_data(), false));
+device_outputs.emplace_back(make_variable(var.tensor_data(), false));
 }
 }
 }
@@ -255,7 +255,7 @@ std::vector<at::Tensor>& scatter_out(
 bool same_ndim = out_sizes.size() == tensor.dim();
 if (same_ndim) {
 total_size += out_sizes[dim];
-chunk_sizes.push_back(out_sizes[dim]);
+chunk_sizes.emplace_back(out_sizes[dim]);
 out_sizes[dim] = tensor.size(dim);
 }
 TORCH_CHECK(
@@ -379,7 +379,7 @@ static inline at::Tensor& _gather_out_impl(
 std::vector<int64_t> chunk_sizes;
 chunk_sizes.reserve(tensors.size());
 for (auto& tensor : tensors) {
-chunk_sizes.push_back(tensor.size(dim));
+chunk_sizes.emplace_back(tensor.size(dim));
 }
 auto chunks =
 out_tensor.split_with_sizes(/*split_sizes=*/chunk_sizes, /*dim=*/dim);

View File

@@ -41,13 +41,13 @@ std::unique_ptr<LoweringContext> LoweringContext::Create(
 c10::ArrayRef<const Node*> post_order,
 Util::EmissionMap emit_status) {
 return getBackend()->CreateLoweringContext(
-name, device, post_order, emit_status);
+name, std::move(device), post_order, emit_status);
 }
 std::unique_ptr<LoweringContext> LoweringContext::Create(
 const std::string& name,
 BackendDevice device) {
-return getBackend()->CreateLoweringContext(name, device);
+return getBackend()->CreateLoweringContext(name, std::move(device));
 }
 } // namespace lazy

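Both Create overloads above take BackendDevice by value, so moving it on its last use forwards the caller's copy instead of making a second one. The same pattern in isolation (Device and Context are illustrative types):

#include <string>
#include <utility>

struct Device {
  std::string spec;               // illustrative payload that is costly to copy
};

struct Context {
  std::string name;
  Device device;
};

Context make_context(std::string name, Device device) {
  // Last use of the by-value parameters: move rather than copy them onward.
  return Context{std::move(name), std::move(device)};
}

int main() {
  Context c = make_context("lowering", Device{"cuda:0"});
  (void)c;
}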
View File

@@ -367,7 +367,7 @@ std::vector<LazyTensorPtr> GetLtcTensors(c10::ArrayRef<at::Tensor> tensors) {
 std::vector<LazyTensorPtr> ltc_tensors;
 ltc_tensors.reserve(tensors.size());
 for (const auto& tensor : tensors) {
-ltc_tensors.push_back(TryGetLtcTensor(tensor));
+ltc_tensors.emplace_back(TryGetLtcTensor(tensor));
 }
 return ltc_tensors;
 }

View File

@@ -166,7 +166,7 @@ void initLazyBindings(PyObject* module) {
 std::vector<LazyTensorPtr> xtensors;
 xtensors.reserve(tensors.size());
 for (auto& tensor : tensors) {
-xtensors.push_back(TryGetLtcTensor(tensor));
+xtensors.emplace_back(TryGetLtcTensor(tensor));
 }
 auto hash = LazyGraphExecutor::Get()->GetGraphHash(xtensors);
 std::string bin((const char*)&hash, sizeof(hash));

View File

@@ -437,7 +437,7 @@ void initDispatchBindings(PyObject* module) {
 std::vector<std::string> states;
 states.reserve(danglingImpls.size());
 for (auto& danglingImpl : danglingImpls) {
-states.push_back(danglingImpl.dumpState());
+states.emplace_back(danglingImpl.dumpState());
 }
 return states;
@@ -454,7 +454,7 @@ void initDispatchBindings(PyObject* module) {
 if (!op.overload_name.empty()) {
 ss << "." << op.overload_name;
 }
-names.push_back(ss.str());
+names.emplace_back(ss.str());
 }
 return names;
@@ -613,7 +613,7 @@ void initDispatchBindings(PyObject* module) {
 std::vector<std::string> names;
 names.reserve(op_names.size());
 for (auto& op : op_names) {
-names.push_back(
+names.emplace_back(
 op.name +
 (op.overload_name.empty() ? "" : "." + op.overload_name));
 }

View File

@@ -261,7 +261,7 @@ std::vector<c10::FunctionSchema> SchemaInfo::getNonDeterministicOps() {
 std::vector<c10::FunctionSchema> nondeterministic_ops;
 nondeterministic_ops.reserve(nondeterministic_op_strings.size());
 for (const std::string& signature : nondeterministic_op_strings) {
-nondeterministic_ops.push_back(torch::jit::parseSchema(signature));
+nondeterministic_ops.emplace_back(torch::jit::parseSchema(signature));
 }
 return nondeterministic_ops;

View File

@@ -17,7 +17,7 @@ using SchemaSpecialCasePair =
 struct TORCH_API SchemaInfo {
 public:
-explicit SchemaInfo(const c10::FunctionSchema& schema)
+explicit SchemaInfo(c10::FunctionSchema schema)
 : schema_(std::move(schema)),
 alias_maps_current_(false),
 has_init_(false) {}

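The constructor change above is the sink-parameter idiom: std::move applied to a const reference silently falls back to a copy, so the schema is now taken by value (callers pay one copy or one move at the call site) and then genuinely moved into the member. Sketch with an illustrative type:

#include <string>
#include <utility>

struct Schema {
  std::string repr;               // stand-in for an expensive-to-copy payload
};

struct Info {
  // Sink parameter: rvalue callers pay a move, lvalue callers exactly one copy.
  explicit Info(Schema schema) : schema_(std::move(schema)) {}
  Schema schema_;
};

int main() {
  Schema s{"aten::add(Tensor self, Tensor other) -> Tensor"};
  Info kept(s);                   // copy into the parameter, then a move
  Info taken(std::move(s));       // move into the parameter, then a move
}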
View File

@@ -26,7 +26,7 @@ struct StashTorchDispatchModeGuard {
 struct StashTorchDispatchStackGuard {
 public:
 StashTorchDispatchStackGuard() {
-const auto old = c10::impl::TorchDispatchModeTLS::get_state();
+auto old = c10::impl::TorchDispatchModeTLS::get_state();
 c10::impl::TorchDispatchModeTLS::set_state(saved_state_);
 saved_state_ = std::move(old);
 }
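Dropping const on the local above matters because std::move of a const object still selects the copy constructor/assignment; only a non-const local can actually be moved from, which is what the subsequent std::move(old) relies on. Minimal demonstration:

#include <utility>
#include <vector>

std::vector<int> get_state() { return std::vector<int>(1000, 0); }

int main() {
  std::vector<int> saved;

  const auto old_state = get_state();
  saved = std::move(old_state);   // const blocks the move: this is a full copy

  auto state = get_state();
  saved = std::move(state);       // genuine move: the buffer is stolen, no copy
}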