replace torch::make_unique with std::make_unique (#108866)

It should be safe to remove the old torch::make_unique functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108866
Approved by: https://github.com/albanD
This commit is contained in:
cyy
2023-09-14 20:52:21 +00:00
committed by PyTorch MergeBot
parent f03b8abd47
commit 03e35efbf7
54 changed files with 53 additions and 134 deletions

View File

@ -167,7 +167,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
assetName->toStdString().c_str());
}
JITCallGuard guard;
module_ = torch::jit::load(torch::make_unique<MemoryReadAdapter>(
module_ = torch::jit::load(std::make_unique<MemoryReadAdapter>(
assetBuffer, AAsset_getLength(asset)));
AAsset_close(asset);
module_.eval();

View File

@ -132,7 +132,7 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
}
LiteJITCallGuard guard;
module_ =
torch::jit::_load_for_mobile(torch::make_unique<MemoryReadAdapter>(
torch::jit::_load_for_mobile(std::make_unique<MemoryReadAdapter>(
assetBuffer, AAsset_getLength(asset)));
AAsset_close(asset);
deviceType_ = deviceJniCodeToDeviceType(device);

View File

@ -71,7 +71,6 @@ def libtorch_generated_sources(gencode_pattern):
# copied from https://github.com/pytorch/pytorch/blob/f99a693cd9ff7a9b5fdc71357dac66b8192786d3/aten/src/ATen/core/CMakeLists.txt
jit_core_headers = [
"torch/csrc/utils/memory.h",
"torch/csrc/Export.h",
"torch/csrc/jit/frontend/source_range.h",
"torch/csrc/jit/serialization/callstack_debug_info_serialization.h",

View File

@ -872,7 +872,7 @@ TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
for (const auto i : c10::irange(num_replicas)) {
samplers.emplace_back(
torch::make_unique<samplers::DistributedRandomSampler>(
std::make_unique<samplers::DistributedRandomSampler>(
sample_count, num_replicas, i, allow_duplicates));
}
@ -969,7 +969,7 @@ TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
for (const auto i : c10::irange(num_replicas)) {
samplers.emplace_back(
torch::make_unique<samplers::DistributedSequentialSampler>(
std::make_unique<samplers::DistributedSequentialSampler>(
sample_count, num_replicas, i, allow_duplicates));
}

View File

@ -1,7 +1,5 @@
#include <gtest/gtest.h>
#include <torch/csrc/utils/memory.h>
#include <c10/util/Optional.h>
struct TestValue {
@ -13,7 +11,7 @@ struct TestValue {
};
TEST(MakeUniqueTest, ForwardRvaluesCorrectly) {
auto ptr = torch::make_unique<TestValue>(123);
auto ptr = std::make_unique<TestValue>(123);
ASSERT_FALSE(ptr->lvalue_.has_value());
ASSERT_TRUE(ptr->rvalue_.has_value());
ASSERT_EQ(*ptr->rvalue_, 123);
@ -21,7 +19,7 @@ TEST(MakeUniqueTest, ForwardRvaluesCorrectly) {
TEST(MakeUniqueTest, ForwardLvaluesCorrectly) {
int x = 5;
auto ptr = torch::make_unique<TestValue>(x);
auto ptr = std::make_unique<TestValue>(x);
ASSERT_TRUE(ptr->lvalue_.has_value());
ASSERT_EQ(*ptr->lvalue_, 5);
ASSERT_FALSE(ptr->rvalue_.has_value());
@ -29,7 +27,7 @@ TEST(MakeUniqueTest, ForwardLvaluesCorrectly) {
TEST(MakeUniqueTest, CanConstructUniquePtrOfArray) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
auto ptr = torch::make_unique<int[]>(3);
auto ptr = std::make_unique<int[]>(3);
// Value initialization is required by the standard.
ASSERT_EQ(ptr[0], 0);
ASSERT_EQ(ptr[1], 0);

View File

@ -7,7 +7,6 @@
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/TensorOperators.h>
@ -23,7 +22,7 @@ class TopologicalMoveTest : public ::testing::Test {
protected:
TopologicalMoveTest() {
createGraph();
aliasDb = torch::make_unique<AliasDb>(graph);
aliasDb = std::make_unique<AliasDb>(graph);
}
// Nodes are named after their output.

View File

@ -3007,7 +3007,7 @@ graph(%x.1 : Tensor):
return (%y))IR",
&*graph);
{
auto func = torch::make_unique<GraphFunction>(
auto func = std::make_unique<GraphFunction>(
"name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::PROFILING);
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
Stack stack = {a};
@ -3020,7 +3020,7 @@ graph(%x.1 : Tensor):
->run(*g);
}
{
auto func = torch::make_unique<GraphFunction>(
auto func = std::make_unique<GraphFunction>(
"name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::SIMPLE);
auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
Stack stack = {a};

View File

@ -1147,7 +1147,7 @@ struct CudaGraphFuser {
}
void refreshAliasDb() {
aliasDb_ = torch::make_unique<AliasDb>(graph_);
aliasDb_ = std::make_unique<AliasDb>(graph_);
}
void removeNoopBinaryOps(Block* block) {

View File

@ -3,7 +3,6 @@
#include <torch/data/dataloader/stateful.h>
#include <torch/data/dataloader/stateless.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
#include <c10/util/Exception.h>
@ -23,7 +22,7 @@ torch::disable_if_t<
Dataset::is_stateful,
std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) {
return torch::make_unique<StatelessDataLoader<Dataset, Sampler>>(
return std::make_unique<StatelessDataLoader<Dataset, Sampler>>(
std::move(dataset), std::move(sampler), std::move(options));
}
@ -51,7 +50,7 @@ template <typename Dataset, typename = torch::enable_if_t<Dataset::is_stateful>>
std::unique_ptr<StatefulDataLoader<Dataset>> make_data_loader(
Dataset dataset,
DataLoaderOptions options = DataLoaderOptions()) {
return torch::make_unique<StatefulDataLoader<Dataset>>(
return std::make_unique<StatefulDataLoader<Dataset>>(
std::move(dataset), std::move(options));
}
} // namespace data

View File

@ -8,7 +8,6 @@
#include <torch/data/worker_exception.h>
#include <torch/types.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
#include <c10/util/Exception.h>
@ -62,15 +61,14 @@ class DataLoaderBase {
"Attempted to get a new DataLoader iterator "
"while another iterator is not yet exhausted");
reset();
return Iterator<Batch>(torch::make_unique<detail::ValidIterator<Batch>>(
return Iterator<Batch>(std::make_unique<detail::ValidIterator<Batch>>(
[this] { return this->next(); }));
}
/// Returns a special "sentinel" iterator that compares equal with a
/// non-sentinel iterator once the DataLoader is exhausted.
Iterator<Batch> end() {
return Iterator<Batch>(
torch::make_unique<detail::SentinelIterator<Batch>>());
return Iterator<Batch>(std::make_unique<detail::SentinelIterator<Batch>>());
}
/// Joins the DataLoader's worker threads and drains internal queues.
@ -215,10 +213,10 @@ class DataLoaderBase {
/// `enforce_ordering` option.
std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
if (options_.enforce_ordering) {
return torch::make_unique<detail::sequencers::OrderedSequencer<Result>>(
return std::make_unique<detail::sequencers::OrderedSequencer<Result>>(
options_.max_jobs);
}
return torch::make_unique<detail::sequencers::NoSequencer<Result>>();
return std::make_unique<detail::sequencers::NoSequencer<Result>>();
}
/// The options the DataLoader was configured with.

View File

@ -38,7 +38,7 @@ class StatefulDataLoader : public DataLoaderBase<
StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
: super(
std::move(options),
torch::make_unique<Dataset>(std::move(dataset))) {
std::make_unique<Dataset>(std::move(dataset))) {
for (const auto w : c10::irange(this->options_.workers)) {
// As opposed to the stateless case, here all worker threads access the
// same underlying dataset.

View File

@ -3,8 +3,6 @@
#include <torch/data/dataloader/base.h>
#include <torch/data/worker_exception.h>
#include <torch/csrc/utils/memory.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
@ -52,7 +50,7 @@ class StatelessDataLoader : public DataLoaderBase<
}
if (this->options_.workers == 0) {
this->main_thread_dataset_ =
torch::make_unique<Dataset>(std::move(dataset));
std::make_unique<Dataset>(std::move(dataset));
}
}

View File

@ -2,7 +2,6 @@
#include <c10/util/irange.h>
#include <torch/arg.h>
#include <torch/csrc/utils/memory.h>
#include <torch/data/datasets/stateful.h>
#include <torch/data/samplers.h>
#include <queue>
@ -391,7 +390,7 @@ class ChunkDataset final
// Throw out any existing cached batch in the buffer and re-creates a new
// chunk buffer.
batch_buffer_ = torch::make_unique<
batch_buffer_ = std::make_unique<
detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
options_.batch_size(), example_sampler_, options_.cache_size());

View File

@ -8,7 +8,6 @@
#include <torch/types.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
#include <ATen/Device.h>
@ -340,7 +339,7 @@ std::unique_ptr<AnyModulePlaceholder> AnyModule::make_holder(
!std::is_void<ReturnType>::value,
"AnyModule cannot store modules that return void "
"(you can return a dummy value).");
return torch::make_unique<
return std::make_unique<
AnyModuleHolder<decay_t<ModuleType>, ArgumentTypes...>>(
std::move(module));
}

View File

@ -116,12 +116,12 @@ struct AnyModuleHolder : public AnyModulePlaceholder {
}
std::unique_ptr<AnyModulePlaceholder> copy() const override {
return torch::make_unique<AnyModuleHolder>(*this);
return std::make_unique<AnyModuleHolder>(*this);
}
std::unique_ptr<AnyModulePlaceholder> clone_module(
optional<Device> device) const override {
return torch::make_unique<AnyModuleHolder>(
return std::make_unique<AnyModuleHolder>(
std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
}

View File

@ -6,7 +6,6 @@
#include <torch/types.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
#include <memory>
@ -41,8 +40,8 @@ class AnyValue {
template <typename T>
// NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
explicit AnyValue(T&& value)
: content_(
torch::make_unique<Holder<decay_t<T>>>(std::forward<T>(value))) {}
: content_(std::make_unique<Holder<decay_t<T>>>(std::forward<T>(value))) {
}
/// Returns a pointer to the value contained in the `AnyValue` if the type
/// passed as template parameter matches the type of the value stored, and
@ -110,7 +109,7 @@ class AnyValue {
explicit Holder(U&& value_) noexcept
: Placeholder(typeid(T)), value(std::forward<U>(value_)) {}
std::unique_ptr<Placeholder> clone() const override {
return torch::make_unique<Holder<T>>(value);
return std::make_unique<Holder<T>>(value);
}
T value;
};

View File

@ -7,7 +7,6 @@
#include <torch/types.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/variadic.h>
#include <ATen/Device.h>

View File

@ -5,7 +5,6 @@
#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/memory.h>
#include <torch/library.h>
using namespace at;

View File

@ -9,7 +9,6 @@
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/utils/memory.h>
#include <torch/library.h>
#include <utility>

View File

@ -7,7 +7,6 @@
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
@ -1586,7 +1585,7 @@ void GraphTask::init_to_execute(
// In terms of populating the rest of exec_info though, you can basically
// think of this as the same as setting `needed_` is true directly.
if (!info.captures_) {
info.captures_ = make_unique<std::vector<ExecInfo::Capture>>();
info.captures_ = std::make_unique<std::vector<ExecInfo::Capture>>();
}
info.captures_->emplace_back(output_edge.input_nr, output_idx++);
}

View File

@ -20,7 +20,6 @@
#include <torch/csrc/autograd/utils/lambda_post_hook.h>
#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <torch/csrc/utils/memory.h>
namespace c10d {
namespace {
@ -185,7 +184,7 @@ Reducer::Reducer(
// Hook to execute after the gradient accumulator has executed.
hooks_.emplace_back(
grad_accumulator->add_post_hook(
torch::make_unique<torch::autograd::utils::LambdaPostHook>(
std::make_unique<torch::autograd::utils::LambdaPostHook>(
[=](const torch::autograd::variable_list& outputs,
const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32

View File

@ -8,7 +8,6 @@
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>
@ -132,7 +131,7 @@ struct TORCH_API CompilationUnit {
if (shouldMangle) {
name = mangle(name);
}
auto fn = torch::make_unique<GraphFunction>(
auto fn = std::make_unique<GraphFunction>(
std::move(name), std::move(graph), nullptr);
auto ret = fn.get();
register_function(std::move(fn));

View File

@ -3,7 +3,6 @@
#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {
@ -77,7 +76,7 @@ struct TORCH_API GraphFunction : public Function {
}
Function& setSchema(FunctionSchema schema) override {
schema_ = make_unique<FunctionSchema>(std::move(schema));
schema_ = std::make_unique<FunctionSchema>(std::move(schema));
return *this;
}

View File

@ -11,7 +11,6 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/utils/memory.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>

View File

@ -6,7 +6,6 @@
#include <c10/util/Optional.h>
#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/cpu/temp_file.h>
#include <torch/csrc/utils/memory.h>
#include <cstdlib>
#include <iostream>
@ -333,7 +332,7 @@ FusedKernelCPU::FusedKernelCPU(
runCompiler(cpp_file.name(), so_file.name());
if (debugFuser() >= 2)
disas(so_file.name());
so_lib = make_unique<at::DynamicLibrary>(so_file.name().c_str());
so_lib = std::make_unique<at::DynamicLibrary>(so_file.name().c_str());
#pragma GCC diagnostic ignored "-Wpedantic"
kernel =
reinterpret_cast<void (*)(uint32_t, void**)>(so_lib->sym(name_.c_str()));

View File

@ -2,7 +2,6 @@
#include <c10/util/Optional.h>
#include <torch/csrc/jit/frontend/tree.h>
#include <torch/csrc/utils/memory.h>
namespace torch::jit {

View File

@ -5432,7 +5432,7 @@ std::unique_ptr<Function> CompilationUnit::define(
auto graph = std::make_shared<Graph>();
graph->set_op_version(operator_set_version);
auto fn = torch::make_unique<GraphFunction>(std::move(name), graph, creator);
auto fn = std::make_unique<GraphFunction>(std::move(name), graph, creator);
if (self) {
// Register this as a method on `self`'s type
if (type == CompilationUnit::FunctionType::Hook) {

View File

@ -8,7 +8,6 @@
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/utils/memory.h>
#include <fstream>
namespace torch::jit {

View File

@ -525,7 +525,7 @@ mobile::Module _load_for_mobile_impl(
}
const size_t model_size = rai != nullptr ? rai->size() : 0;
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
if (module_load_options &
MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS) {
// ExtraFilesMap is serialized with a "extra/", hence it is necessary to
@ -694,7 +694,7 @@ void _load_extra_only_for_mobile(
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
BytecodeDeserializer deserializer(std::move(reader));
deserializer.deserialize_only_extra(device, extra_files);
break;

View File

@ -170,7 +170,7 @@ c10::IValue IValueUnpickler::readArchive(
std::map<std::string, at::Tensor> load_parameters_from_zip(
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device) {
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
IValueUnpickler unpickler(std::move(reader));
auto result = unpickler.deserialize(device).toGenericDict();
std::map<std::string, at::Tensor> map;

View File

@ -13,7 +13,6 @@
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/utils/memory.h>
#include <utility>

View File

@ -4,7 +4,6 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/utils/memory.h>
#include <cstddef>
#include <limits>
@ -28,7 +27,7 @@ struct FunctionalGraphSlicer {
// subgraphs, invalidating the AliasDb, so we need to do our analysis
// first.
for (size_t i = 0; i < MAX_NUM_ITERATIONS && changed; ++i) {
aliasDb_ = torch::make_unique<AliasDb>(graph_);
aliasDb_ = std::make_unique<AliasDb>(graph_);
AnalyzeFunctionalSubset(graph_->block());
changed = CreateFunctionalGraphsImpl(graph_->block());
}

View File

@ -4,7 +4,6 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/utils/memory.h>
#include <unordered_map>

View File

@ -344,7 +344,7 @@ class AttributePropagator {
void recordMutableAttrs(std::shared_ptr<Graph>& graph) {
std::stack<Block*> blocks({graph->block()});
std::unique_ptr<AliasDb> aliasDb =
torch::make_unique<AliasDb>(graph, /* isFrozen */ true);
std::make_unique<AliasDb>(graph, /* isFrozen */ true);
while (!blocks.empty()) {
Block* block = blocks.top();
blocks.pop();

View File

@ -9,7 +9,6 @@
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>

View File

@ -7,7 +7,6 @@
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {

View File

@ -110,7 +110,7 @@ void InplaceMKLDNNSubgraph(std::shared_ptr<Graph> graph) {
// CALCULATE ALIASING SETS
auto aliasDb = torch::make_unique<AliasDb>(graph);
auto aliasDb = std::make_unique<AliasDb>(graph);
// map from Value to its Aliasing Set
std::unordered_map<Value*, ValueSetPtr> alias_mapping;

View File

@ -3,7 +3,6 @@
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/integer_value_refinement.h>
#include <torch/csrc/jit/passes/value_refinement_utils.h>
#include <torch/csrc/utils/memory.h>
#include <utility>

View File

@ -12,7 +12,6 @@
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/peephole_non_tensor.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {

View File

@ -6,7 +6,6 @@
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_alias_sensitive.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
#include <unordered_set>
namespace torch {
@ -20,7 +19,7 @@ struct PeepholeOptimizeAliasSensitiveImpl {
std::shared_ptr<Graph> graph,
bool shape_peepholes)
: graph_(std::move(graph)),
aliasDb_(torch::make_unique<AliasDb>(graph_)),
aliasDb_(std::make_unique<AliasDb>(graph_)),
shape_peepholes_(shape_peepholes) {}
bool run() {

View File

@ -8,7 +8,6 @@
#include <torch/csrc/jit/passes/value_refinement_utils.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <torch/csrc/utils/memory.h>
#include <limits>
#include <utility>
@ -161,7 +160,7 @@ struct PeepholeOptimizeListIdiomsImpl {
std::shared_ptr<Graph> graph,
bool refine_list_len)
: graph_(std::move(graph)),
aliasDb_(torch::make_unique<AliasDb>(graph_)),
aliasDb_(std::make_unique<AliasDb>(graph_)),
refine_list_len_(refine_list_len) {}
bool run() {

View File

@ -971,7 +971,7 @@ std::unique_ptr<GraphFunction> SubGraphCloneHelper::buildGraphFromNodes(
auto build_observer_graph = [&](GraphFunction& func) {
buildObserverSubgraph(nodes, func.graph());
};
return torch::make_unique<GraphFunction>(
return std::make_unique<GraphFunction>(
name, observer_subgraph, build_observer_graph);
}

View File

@ -4,7 +4,6 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {

View File

@ -4,7 +4,6 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {

View File

@ -5,7 +5,6 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {

View File

@ -23,7 +23,6 @@
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/utils/memory.h>
#include <algorithm>
#include <memory>
#include <numeric>

View File

@ -23,7 +23,6 @@
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/utils/memory.h>
#include <utility>
@ -550,7 +549,7 @@ class TensorExprFuser {
}
void run() {
aliasDb_ = torch::make_unique<AliasDb>(graph_);
aliasDb_ = std::make_unique<AliasDb>(graph_);
RemoveRedundantProfiles(graph_);
GRAPH_DUMP("After removing redundant profile nodes: ", graph_);
createFusionGroups(graph_->block());

View File

@ -1,7 +1,6 @@
#include <torch/csrc/jit/passes/utils/memory_dag.h>
#include <c10/util/flat_hash_map.h>
#include <torch/csrc/utils/memory.h>
#include <algorithm>
#include <queue>

View File

@ -8,7 +8,6 @@
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {

View File

@ -562,7 +562,7 @@ struct CodeImpl {
};
auto empty_graph = std::make_shared<Graph>();
auto func = torch::make_unique<GraphFunction>(
auto func = std::make_unique<GraphFunction>(
"bailout", empty_graph, build_bailout_graph);
function_table_.emplace_back(func.get());
bailout_functions_.emplace_back(std::move(func));

View File

@ -32,7 +32,7 @@ void fuseStaticSubgraphs(std::shared_ptr<Graph> graph, size_t min_size) {
RemoveTensorMutation(graph);
ConstantPropagation(graph);
EliminateDeadCode(graph);
auto aliasDb = torch::make_unique<AliasDb>(graph);
auto aliasDb = std::make_unique<AliasDb>(graph);
createFusionGroups(graph->block(), aliasDb.get(), min_size);
ConstantPooling(graph);
ConstantPropagation(graph);

View File

@ -383,7 +383,7 @@ Module import_ir_module(
// NOTE: Zipformat can be large files. So using stream version directly
// instead of reading the file all at once.
if (getFileFormat(in) != FileFormat::FlatbufferFileFormat) {
auto reader = torch::make_unique<PyTorchStreamReader>(&in);
auto reader = std::make_unique<PyTorchStreamReader>(&in);
reader->setShouldLoadDebugSymbol(load_debug_files);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files, restore_shapes);
@ -432,7 +432,7 @@ Module import_ir_module(
// NOTE: Zipformat can be large files. So using stream version directly
// instead of reading the file all at once.
if (getFileFormat(filename) != FileFormat::FlatbufferFileFormat) {
auto reader = torch::make_unique<PyTorchStreamReader>(filename);
auto reader = std::make_unique<PyTorchStreamReader>(filename);
reader->setShouldLoadDebugSymbol(load_debug_files);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files, restore_shapes);
@ -548,7 +548,7 @@ Module _load_jit_module_from_bytes(
}
case FileFormat::ZipFileFormat: {
auto rai = std::make_unique<MemoryReadAdapter>(data.get(), size);
auto reader = torch::make_unique<PyTorchStreamReader>(std::move(rai));
auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files, restore_shapes);
}

View File

@ -1,6 +1,5 @@
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/utils/python_strings.h>
#include <c10/util/irange.h>
@ -136,25 +135,25 @@ std::vector<std::string> _splitString(
std::unique_ptr<Type> _buildType(std::string type_name, bool is_nullable) {
std::unique_ptr<Type> result;
if (type_name == "float") {
result = torch::make_unique<MultiType>(MultiType{"float", "int", "long"});
result = std::make_unique<MultiType>(MultiType{"float", "int", "long"});
} else if (type_name == "int") {
result = torch::make_unique<MultiType>(MultiType{"int", "long"});
result = std::make_unique<MultiType>(MultiType{"int", "long"});
} else if (type_name.find("tuple[") == 0) {
auto type_list = type_name.substr(6);
type_list.pop_back();
std::vector<std::unique_ptr<Type>> types;
for (auto& type : _splitString(type_list, ","))
types.emplace_back(_buildType(type, false));
result = torch::make_unique<TupleType>(std::move(types));
result = std::make_unique<TupleType>(std::move(types));
} else if (type_name.find("sequence[") == 0) {
auto subtype = type_name.substr(9);
subtype.pop_back();
result = torch::make_unique<SequenceType>(_buildType(subtype, false));
result = std::make_unique<SequenceType>(_buildType(subtype, false));
} else {
result = torch::make_unique<SimpleType>(type_name);
result = std::make_unique<SimpleType>(type_name);
}
if (is_nullable)
result = torch::make_unique<NullableType>(std::move(result));
result = std::make_unique<NullableType>(std::move(result));
return result;
}

View File

@ -1,41 +0,0 @@
#pragma once
#include <memory>
namespace torch {
// Reference:
// https://github.com/llvm-mirror/libcxx/blob/master/include/memory#L3091
template <typename T>
struct unique_type_for {
using value = std::unique_ptr<T>;
};
template <typename T>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
struct unique_type_for<T[]> {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
using unbounded_array = std::unique_ptr<T[]>;
};
template <typename T, size_t N>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
struct unique_type_for<T[N]> {
using bounded_array = void;
};
template <typename T, typename... Args>
typename unique_type_for<T>::value make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
template <typename T>
typename unique_type_for<T>::unbounded_array make_unique(size_t size) {
using U = typename std::remove_extent<T>::type;
return std::unique_ptr<T>(new U[size]());
}
template <typename T, size_t N, typename... Args>
typename unique_type_for<T>::bounded_array make_unique(Args&&...) = delete;
} // namespace torch