Removing some dependency edges from Blob to other caffe2 (#12043)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12043

Re-trying D9979976, this time with all call sites fixed.

D9979976 was reverted because, it seems, one call site wasn't covered by Sandcastle.
I fixed it and used 'grep' to ensure there aren't any more call sites in fbsource.

Reviewed By: ezyang

Differential Revision: D10026392

fbshipit-source-id: cd341514a8e53a40147ea0ee3e52f63bb6444157
This commit is contained in:
Sebastian Messmer
2018-09-25 11:26:48 -07:00
committed by Facebook Github Bot
parent 94c513cc7f
commit 8f0db9bbbb
66 changed files with 380 additions and 371 deletions

View File

@ -163,7 +163,7 @@ void loadInput(
CAFFE_THROW("Not support GPU on mobile.");
#endif
} else {
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
if (input_type_list[i] == "uint8_t") {
@ -200,7 +200,7 @@ void fillInputBlob(
int protos_size = tensor_kv.second.protos_size();
caffe2::TensorProto* tensor_proto =
tensor_kv.second.mutable_protos(iteration % protos_size);
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
if (tensor_proto->data_type() == caffe2::TensorProto::STRING) {
int total_size = tensor_proto->string_data_size();
for (size_t i = 0; i < total_size; i++) {
@ -298,7 +298,7 @@ void writeOutput(
#endif
} else {
writeTextOutput<caffe2::CPUContext, caffe2::TensorCPU>(
workspace->GetBlob(name)->GetMutableTensor(caffe2::CPU),
BlobGetMutableTensor(workspace->GetBlob(name), caffe2::CPU),
output_prefix,
name);
}

View File

@ -137,7 +137,7 @@ int main(int argc, char** argv) {
if (blob == nullptr) {
blob = workspace->CreateBlob(input_names[i]);
}
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
if (input_type_list[i] == "uint8_t") {

View File

@ -12,7 +12,7 @@ namespace caffe2 {
namespace gloo {
void signalFailure(Blob* status_blob, std::exception& /* unused */) {
auto* res = status_blob->GetMutableTensor(CPU);
auto* res = BlobGetMutableTensor(status_blob, CPU);
res->Resize(1);
res->template mutable_data<int32_t>()[0] = 1;
}

View File

@ -22,7 +22,7 @@ static void AddConstInput(const std::vector<int>& shape, const float value,
option.set_device_type(PROTO_CUDA);
CUDAContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CUDA);
auto* tensor = BlobGetMutableTensor(blob, CUDA);
tensor->Resize(shape);
math::Set<float, CUDAContext>(tensor->size(), value,
tensor->mutable_data<float>(),

View File

@ -95,10 +95,10 @@ void BlobToTensorProto(
}
// Set values
if (blob->IsTensorType(CPU)) {
if (BlobIsTensorType(*blob, CPU)) {
const auto& cpu_tensor = blob->template Get<TensorCPU>();
CPUTensorToTensorProto(cpu_tensor, t);
} else if (blob->IsTensorType(CUDA)) {
} else if (BlobIsTensorType(*blob, CUDA)) {
const auto& cuda_tensor = blob->template Get<TensorCUDA>();
const auto cpu_tensor = TensorCPU(cuda_tensor, context);
context->FinishDeviceComputation();

View File

@ -6,16 +6,16 @@
#include <typeinfo>
#include <type_traits>
#include <vector>
#include "caffe2/core/blob_serializer_base.h"
#include "caffe2/core/common.h"
#include <ATen/core/typeid.h>
#include "caffe2/core/logging.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/typeid.h"
#include "caffe2/proto/caffe2_pb.h"
namespace caffe2 {
class Tensor;
/**
* @brief Blob is a general container that hosts a typed pointer.
*
@ -50,15 +50,6 @@ class CAFFE2_API Blob final {
return meta_.Match<T>();
}
bool IsTensorType(DeviceType device_type) const {
bool is_match = meta_.Match<Tensor>();
auto* tensor = static_cast<Tensor*>(pointer_);
if (is_match && tensor && tensor->GetDeviceType() == device_type) {
return true;
}
return false;
}
/**
* Returns the meta info of the blob.
*/
@ -109,9 +100,6 @@ class CAFFE2_API Blob final {
std::is_default_constructible<T>::value,
"GetMutable can't be called with non-default-constructible types. "
"Try using specialized methods");
static_assert(
!std::is_same<T, Tensor>::value,
"Use GetMutableTensor(DeviceType) instead");
if (IsType<T>()) {
return static_cast<T*>(pointer_);
} else {
@ -129,16 +117,6 @@ class CAFFE2_API Blob final {
}
}
inline Tensor* GetMutableTensor(DeviceType device_type) {
if (IsTensorType(device_type)) {
return static_cast<Tensor*>(pointer_);
} else {
VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
<< " DeviceType:" << device_type;
return Reset<Tensor>(new Tensor(device_type));
}
}
/**
* Sets the underlying object to the allocated one. The Blob then takes over
* the ownership of the passed in pointer. If there is already an object in
@ -248,5 +226,29 @@ inline void swap(Blob& lhs, Blob& rhs) {
lhs.swap(rhs);
}
inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
bool is_match = blob.meta().Match<Tensor>();
if (!is_match) {
return false;
}
const Tensor* tensor = &blob.Get<Tensor>();
return tensor && tensor->GetDeviceType() == device_type;
}
/**
 * @brief Returns a mutable Tensor of the requested device type stored in the
 * blob, replacing the blob's content when it does not already hold one.
 *
 * @param blob The blob to read from or reset; it is dereferenced
 *        unconditionally.
 * @param device_type The device type the returned tensor must have.
 * @return The tensor already held by the blob when it is a Tensor whose
 *         GetDeviceType() equals `device_type`; otherwise the pointer
 *         returned by Reset() after handing the blob a newly allocated
 *         Tensor(device_type).
 */
inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
  // Replace the content when the blob holds something other than a Tensor,
  // or holds a Tensor living on a different device type.
  const bool needs_new_tensor = !blob->IsType<Tensor>() ||
      blob->GetMutable<Tensor>()->GetDeviceType() != device_type;
  if (needs_new_tensor) {
    VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
            << " DeviceType:" << device_type;
    // The freshly allocated tensor is passed to Reset(), which stores it in
    // the blob.
    return blob->Reset<Tensor>(new Tensor(device_type));
  }
  // The blob already holds a matching Tensor, so GetMutable hands back the
  // stored pointer without creating anything.
  return blob->GetMutable<Tensor>();
}
} // namespace caffe2
#endif // CAFFE2_CORE_BLOB_H_

View File

@ -132,7 +132,7 @@ TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
for (int i = 0; i < 6; ++i) { \
cpu_tensor.mutable_data<TypeParam>()[i] = static_cast<TypeParam>(i); \
} \
blob.GetMutableTensor(CUDA)->CopyFrom(cpu_tensor); \
BlobGetMutableTensor(&blob, CUDA)->CopyFrom(cpu_tensor); \
string serialized = SerializeBlob(blob, "test"); \
BlobProto proto; \
CAFFE_ENFORCE(proto.ParseFromString(serialized)); \
@ -149,7 +149,7 @@ TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
} \
Blob new_blob; \
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \
EXPECT_TRUE(new_blob.IsTensorType(CUDA)); \
EXPECT_TRUE(BlobIsTensorType(new_blob, CUDA)); \
Tensor new_cpu_tensor(blob.Get<Tensor>(), CPU); \
EXPECT_EQ(new_cpu_tensor.ndim(), 2); \
EXPECT_EQ(new_cpu_tensor.dim(0), 2); \
@ -199,7 +199,7 @@ TEST(TensorTest, TensorSerializationMultiDevices) {
// Test if the restored blob is still of the same device.
blob.Reset();
EXPECT_NO_THROW(DeserializeBlob(serialized, &blob));
EXPECT_TRUE(blob.IsTensorType(CUDA));
EXPECT_TRUE(BlobIsTensorType(blob, CUDA));
EXPECT_EQ(GetGPUIDForPointer(blob.Get<TensorCUDA>().data<float>()),
gpu_id);
// Test if we force the restored blob on a different device, we
@ -207,7 +207,7 @@ TEST(TensorTest, TensorSerializationMultiDevices) {
blob.Reset();
proto.mutable_tensor()->mutable_device_detail()->set_cuda_gpu_id(0);
EXPECT_NO_THROW(DeserializeBlob(proto.SerializeAsString(), &blob));
EXPECT_TRUE(blob.IsTensorType(CUDA));
EXPECT_TRUE(BlobIsTensorType(blob, CUDA));
EXPECT_EQ(GetGPUIDForPointer(blob.Get<TensorCUDA>().data<float>()), 0);
}
}

View File

@ -363,7 +363,8 @@ void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
auto tensor_proto = blob_proto.tensor();
Deserialize(
tensor_proto,
blob->GetMutableTensor(
BlobGetMutableTensor(
blob,
static_cast<DeviceType>(tensor_proto.device_detail().device_type())));
}

View File

@ -86,15 +86,15 @@ TEST(BlobTest, Blob) {
int* int_unused CAFFE2_UNUSED = blob.GetMutable<int>();
EXPECT_TRUE(blob.IsType<int>());
EXPECT_FALSE(blob.IsType<BlobTestFoo>());
EXPECT_FALSE(blob.IsTensorType(CPU));
EXPECT_FALSE(BlobIsTensorType(blob, CPU));
BlobTestFoo* foo_unused CAFFE2_UNUSED = blob.GetMutable<BlobTestFoo>();
EXPECT_TRUE(blob.IsType<BlobTestFoo>());
EXPECT_FALSE(blob.IsType<int>());
EXPECT_FALSE(blob.IsTensorType(CPU));
EXPECT_FALSE(BlobIsTensorType(blob, CPU));
Tensor* tensor_unused CAFFE2_UNUSED = blob.GetMutableTensor(CPU);
EXPECT_TRUE(blob.IsTensorType(CPU));
Tensor* tensor_unused CAFFE2_UNUSED = BlobGetMutableTensor(&blob, CPU);
EXPECT_TRUE(BlobIsTensorType(blob, CPU));
EXPECT_FALSE(blob.IsType<BlobTestFoo>());
EXPECT_FALSE(blob.IsType<int>());
}
@ -600,7 +600,7 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) {
#define TEST_SERIALIZATION_WITH_TYPE(TypeParam, field_name) \
TEST(TensorTest, TensorSerialization_##TypeParam) { \
Blob blob; \
Tensor* tensor = blob.GetMutableTensor(CPU); \
Tensor* tensor = BlobGetMutableTensor(&blob, CPU); \
tensor->Resize(2, 3); \
for (int i = 0; i < 6; ++i) { \
tensor->mutable_data<TypeParam>()[i] = static_cast<TypeParam>(i); \
@ -621,7 +621,7 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) {
} \
Blob new_blob; \
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \
EXPECT_TRUE(new_blob.IsTensorType(CPU)); \
EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); \
const TensorCPU& new_tensor = blob.Get<TensorCPU>(); \
EXPECT_EQ(new_tensor.ndim(), 2); \
EXPECT_EQ(new_tensor.dim(0), 2); \
@ -634,7 +634,7 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) {
\
TEST(EmptyTensorTest, TensorSerialization_##TypeParam) { \
Blob blob; \
TensorCPU* tensor = blob.GetMutableTensor(CPU); \
TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); \
tensor->Resize(0, 3); \
tensor->mutable_data<TypeParam>(); \
string serialized = SerializeBlob(blob, "test"); \
@ -650,7 +650,7 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) {
EXPECT_EQ(tensor_proto.field_name##_size(), 0); \
Blob new_blob; \
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \
EXPECT_TRUE(new_blob.IsTensorType(CPU)); \
EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); \
const TensorCPU& new_tensor = blob.Get<TensorCPU>(); \
EXPECT_EQ(new_tensor.ndim(), 2); \
EXPECT_EQ(new_tensor.dim(0), 0); \
@ -669,7 +669,7 @@ TEST_SERIALIZATION_WITH_TYPE(int64_t, int64_data)
TEST(TensorTest, TensorSerialization_CustomType) {
Blob blob;
TensorCPU* tensor = blob.GetMutableTensor(CPU);
TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU);
tensor->Resize(2, 3);
for (int i = 0; i < 6; ++i) {
tensor->mutable_data<BlobTestFoo>()[i].val = i;
@ -681,7 +681,7 @@ TEST(TensorTest, TensorSerialization_CustomType) {
EXPECT_EQ(proto.type(), "Tensor");
Blob new_blob;
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob));
EXPECT_TRUE(new_blob.IsTensorType(CPU));
EXPECT_TRUE(BlobIsTensorType(new_blob, CPU));
const TensorCPU& new_tensor = blob.Get<TensorCPU>();
EXPECT_EQ(new_tensor.ndim(), 2);
EXPECT_EQ(new_tensor.dim(0), 2);
@ -696,7 +696,7 @@ TEST(TensorTest, TensorSerialization_CustomType) {
TEST(TensorTest, Half) {
const int64_t kSize = 3000000;
Blob blob;
TensorCPU* tensor = blob.GetMutableTensor(CPU);
TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU);
tensor->Resize(kSize);
for (int i = 0; i < tensor->size(); ++i) {
tensor->mutable_data<at::Half>()[i].x = i % 10000;
@ -724,7 +724,7 @@ TEST(TensorTest, Half) {
}
Blob new_blob;
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob));
EXPECT_TRUE(new_blob.IsTensorType(CPU));
EXPECT_TRUE(BlobIsTensorType(new_blob, CPU));
const TensorCPU& new_tensor = blob.Get<TensorCPU>();
EXPECT_EQ(new_tensor.ndim(), 1);
EXPECT_EQ(new_tensor.dim(0), kSize);
@ -860,7 +860,7 @@ TYPED_TEST(TypedTensorTest, BigTensorSerialization) {
{
VLOG(1) << "Test begin";
Blob blob;
Tensor* tensor = blob.GetMutableTensor(CPU);
Tensor* tensor = BlobGetMutableTensor(&blob, CPU);
VLOG(1) << "Allocating blob";
tensor->Resize(d1, d2);
auto mutableData = tensor->mutable_data<TypeParam>();
@ -903,7 +903,7 @@ TYPED_TEST(TypedTensorTest, BigTensorSerialization) {
load_op->Run();
VLOG(1) << "Reading blob from workspace";
auto new_blob = ws.GetBlob("test");
EXPECT_TRUE(new_blob->IsTensorType(CPU));
EXPECT_TRUE(BlobIsTensorType(*new_blob, CPU));
const auto& new_tensor = new_blob->Get<TensorCPU>();
EXPECT_EQ(new_tensor.ndim(), d1);
@ -1030,7 +1030,7 @@ TEST(CustomChunkSize, BigTensorSerialization) {
int64_t size = d1 * d2;
Blob blob;
TensorCPU* tensor = blob.GetMutableTensor(CPU);
TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU);
tensor->Resize(d1, d2);
tensor->mutable_data<float>();
std::mutex mutex;

View File

@ -122,7 +122,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
static_assert(
std::is_same<T, Tensor>::value,
"Output(int, DeviceType) is only available for Tensor");
return outputs_.at(idx)->GetMutableTensor(type);
return BlobGetMutableTensor(outputs_.at(idx), type);
}
template <typename T>
@ -149,7 +149,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
}
inline bool InputIsTensorType(int idx, DeviceType device_type) {
return inputs_.at(idx)->IsTensorType(device_type);
return BlobIsTensorType(*inputs_.at(idx), device_type);
}
template <typename T>
@ -162,7 +162,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
}
inline bool OutputIsTensorType(int idx, DeviceType type) {
return outputs_.at(idx)->IsTensorType(type);
return BlobIsTensorType(*outputs_.at(idx), type);
}
inline int InputSize() const {

View File

@ -131,7 +131,8 @@ struct WorkspaceIdInjector {
"Integer overflow while calculating GLOBAL_WORKSPACE_ID blob");
int32_t global_ws_id = (seq_++) + (static_cast<int32_t>(node_id) << 16);
Blob* global_ws_id_blob = workspace->CreateLocalBlob(GLOBAL_WORKSPACE_ID);
TensorCPU* global_ws_id_tensor = global_ws_id_blob->GetMutableTensor(CPU);
TensorCPU* global_ws_id_tensor =
BlobGetMutableTensor(global_ws_id_blob, CPU);
global_ws_id_tensor->Resize();
global_ws_id_tensor->template mutable_data<int32_t>()[0] = global_ws_id;
VLOG(1) << "Adding " << GLOBAL_WORKSPACE_ID << " = " << global_ws_id;

View File

@ -151,7 +151,7 @@ class CAFFE2_API Workspace {
auto* to_blob = CreateBlob(blob);
CAFFE_ENFORCE(to_blob);
const auto& from_tensor = from_blob->template Get<Tensor>();
auto* to_tensor = to_blob->GetMutableTensor(Context::GetDeviceType());
auto* to_tensor = BlobGetMutableTensor(to_blob, Context::GetDeviceType());
to_tensor->CopyFrom(from_tensor);
}
}

View File

@ -33,8 +33,9 @@ class IDEEPConcatOp final : public IDEEPOperator {
if (OperatorBase::InputBlob(i).template IsType<itensor>()) {
inputs.emplace_back(Input(i));
} else {
CAFFE_ENFORCE(OperatorBase::InputBlob(i).IsTensorType(CPU),
"Expect cpu tensor if not itensor");
CAFFE_ENFORCE(
BlobIsTensorType(OperatorBase::InputBlob(i), CPU),
"Expect cpu tensor if not itensor");
auto& tensor_cpu = OperatorBase::Input<Tensor>(i, CPU);
CAFFE_ENFORCE(tensor_cpu.dims().size() == 0 ||
tensor_cpu.size_from_dim(0) == 0,

View File

@ -89,7 +89,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
local_input_blobs_[i]->Reset();
}
input_share_[i] = false;
auto dtensor = local_input_blobs_[i]->GetMutableTensor(CPU);
auto dtensor = BlobGetMutableTensor(local_input_blobs_[i], CPU);
dtensor->Resize(input.get_dims());
if (input.is_public_format()) {
dtensor->ShareExternalPointer(
@ -121,7 +121,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
continue;
}
CAFFE_ENFORCE(
local_output_blobs_[i]->IsTensorType(CPU),
BlobIsTensorType(*local_output_blobs_[i], CPU),
"IDEEP fallback op currently does not support non-TensorCPU "
"output type who needs copying.");
const auto& src = local_output_blobs_[i]->template Get<TensorCPU>();
@ -153,7 +153,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor";
Blob* dst = OperatorBase::OutputBlob(i);
dst->Reset(new Tensor(CPU));
auto dtensor = dst->GetMutableTensor(CPU);
auto dtensor = BlobGetMutableTensor(dst, CPU);
dtensor->Resize(src_dims);
dtensor->ShareData(src);
}

View File

@ -31,7 +31,7 @@ class CopyIDEEPToCPUOp final : public IDEEPOperator {
USE_IDEEP_DEF_ALIASES();
bool RunOnDevice() override {
const auto& input_blob = OperatorBase::InputBlob(0);
if (input_blob.IsTensorType(CPU)) {
if (BlobIsTensorType(input_blob, CPU)) {
VLOG(2) << "Directing sharing of TensorCPU";
const auto& X = OperatorBase::Input<Tensor>(0, CPU);
auto* Y = OperatorBase::Output<Tensor>(0, CPU);

View File

@ -66,10 +66,10 @@ class MKLFallbackOp final : public Operator<MKLContext> {
for (int i = 0; i < InputSize(); ++i) {
if (OperatorBase::InputIsType<MKLMemory<float>>(i)) {
OperatorBase::Input<MKLMemory<float>>(i).CopyTo(
local_input_blobs_[i]->GetMutableTensor(CPU));
BlobGetMutableTensor(local_input_blobs_[i], CPU));
} else if (OperatorBase::InputIsType<MKLMemory<double>>(i)) {
OperatorBase::Input<MKLMemory<double>>(i).CopyTo(
local_input_blobs_[i]->GetMutableTensor(CPU));
BlobGetMutableTensor(local_input_blobs_[i], CPU));
} else {
VLOG(1) << "Input " << i << " is not MKLMemory. Skipping copy.";
// Note(jiayq): This removes a const but conceptually
@ -93,7 +93,7 @@ class MKLFallbackOp final : public Operator<MKLContext> {
continue;
}
CAFFE_ENFORCE(
local_output_blobs_[i]->IsTensorType(CPU),
BlobIsTensorType(*local_output_blobs_[i], CPU),
"MKL fallback op currently does not support non-TensorCPU "
"output type who needs copying.");
const auto& src = local_output_blobs_[i]->template Get<TensorCPU>();

View File

@ -43,7 +43,7 @@ bool CopyFromGLOp<T>::RunOnDevice() {
if (first_run_) {
first_run_ = false;
for (int i = 0; i < Inputs().size(); ++i) {
auto* Y = OperatorBase::Outputs()[i]->GetMutableTensor(CPU);
auto* Y = BlobGetMutableTensor(OperatorBase::Outputs()[i], CPU);
Y->Resize(inputs_[i]->dims());
Y->template mutable_data<float>();
}
@ -54,7 +54,7 @@ bool CopyFromGLOp<T>::RunOnDevice() {
// GLTensor
auto* X = inputs_[i].get();
X->lazy_allocate(Xblob, second_run_, true);
auto* Y = OperatorBase::Outputs()[i]->GetMutableTensor(CPU);
auto* Y = BlobGetMutableTensor(OperatorBase::Outputs()[i], CPU);
Timer timer;
timer.Start();
getTensorCPU(*X, *Y);

View File

@ -27,7 +27,7 @@ template<typename T = float>
void PopulateCPUBlob(Workspace *ws, bool random, std::string name,
std::vector<int> dims, int val = 1, int dist_shift = 0, float variance = 1) {
Blob *blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(dims);
T *t_data = tensor->mutable_data<T>();
std::random_device rd;

View File

@ -489,13 +489,13 @@ class MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocessOp final
"noise_size", 491 /* prime to avoid artifacts */);
// Treaded as half4 in the kernel, so need half4 here.
noiseSize = divRoundUp(noiseSize, 4) * 4;
if (!noiseBlob->IsTensorType(CPU) ||
if (!BlobIsTensorType(*noiseBlob, CPU) ||
noiseBlob->Get<TensorCPU>().size() != noiseSize) {
VLOG(2) << "Initializing stylizer with noise: " << noiseSize;
caffe2::Timer rt;
// Initialize random noise on first use.
// Cache it to maintain temporal consistency.
auto* t = noiseBlob->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(noiseBlob, CPU);
t->Resize(noiseSize);
math::RandGaussian<float, CPUContext>(
t->size(),

View File

@ -94,7 +94,7 @@ void testMPSCNN() {
Workspace ws;
for (auto i = 0; i < N; ++i) {
auto* t = ws.CreateBlob(cpu(i))->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob(cpu(i)), CPU);
t->Resize(BS, C, H, W);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -152,7 +152,7 @@ void testMPSCNN() {
Workspace ws;
for (auto i = 0; i < N; ++i) {
auto* t = ws.CreateBlob(cpu(i))->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob(cpu(i)), CPU);
switch (ndim) {
case 1:
t->Resize(5);
@ -210,7 +210,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNNormalizePlanarYUV Test: ";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batch_size, channels, 8, 13);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -218,14 +218,14 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
t->Resize(1, channels);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
t->size(), 0, 1, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws.CreateBlob("stddev")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("stddev"), CPU);
t->Resize(1, channels);
CPUContext ctx;
math::RandUniform<float, CPUContext>(
@ -290,7 +290,7 @@ void testMPSCNN() {
for (const auto dim : {10, 40}) {
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batchSize, channels, dim, dim);
CPUContext ctx;
// Too noisy.
@ -299,7 +299,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(channels);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -310,7 +310,7 @@ void testMPSCNN() {
// t->mutable_data<float>(), &ctx);
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(channels);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -321,7 +321,7 @@ void testMPSCNN() {
// t->mutable_data<float>(), &ctx);
}
{
auto* t = ws.CreateBlob("pw")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("pw"), CPU);
t->Resize(prelu == PreluTy::SHARED ? 1 : channels);
CPUContext ctx;
// Too noisy.
@ -409,7 +409,7 @@ void testMPSCNN() {
Workspace ws;
const auto channels = array ? 12 : 3;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batch_size, channels, 8, 13);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -417,7 +417,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(shared ? channels : 1);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -480,7 +480,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNSpatialBN Test: " << channels;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batch_size, channels, 8, 13);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -488,7 +488,7 @@ void testMPSCNN() {
}
for (const std::string name : {"scale", "bias", "mean", "var"}) {
auto* t = ws.CreateBlob(name)->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob(name), CPU);
t->Resize(channels);
CPUContext ctx;
// High mean to avoid var division by zero.
@ -575,7 +575,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNFC Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batchSize, CIn, H, W);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -583,7 +583,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(COut, CIn * H * W);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -591,7 +591,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(COut);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -682,8 +682,8 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNPool Test: " << pool;
Workspace ws;
{
auto* t =
ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(
ws.CreateBlob("X_cpu"), CPU);
t->Resize(batchSize, 8, 8, 13);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -784,7 +784,7 @@ void testMPSCNN() {
std::vector<std::vector<size_t>>{{1, 3, 50, 80}, {1, 12, 50, 80}}) {
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(dims);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -860,7 +860,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNPreprocess Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(1, 8, 13, 4);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -869,7 +869,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
t->Resize(3);
CPUContext ctx;
t->mutable_data<float>()[0] = 100;
@ -940,7 +940,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNDeprocess Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(1, 3, 8, 24);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -949,7 +949,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
t->Resize(3);
CPUContext ctx;
t->mutable_data<float>()[0] = 100;
@ -999,7 +999,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNDeprocess Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(1, 3, 1280, 720);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -1008,7 +1008,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
t->Resize(3);
CPUContext ctx;
t->mutable_data<float>()[0] = 30;
@ -1072,7 +1072,8 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNConv Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t =
BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batchSize, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1080,7 +1081,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(8, 12, kernel_h, kernel_w);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1092,7 +1093,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(8);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1188,7 +1189,7 @@ void testMPSCNN() {
Workspace ws;
int output_channels = input_channels * channel_multiplier;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batchSize, input_channels, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1196,7 +1197,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(output_channels, 1, 3, 3);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1204,7 +1205,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(output_channels);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1275,7 +1276,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNConvRelu Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(1, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1283,7 +1284,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(8, 12, 3, 3);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1291,7 +1292,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(8);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1385,7 +1386,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSConv Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(1, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1393,7 +1394,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(8, 12, 3, 3);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1401,7 +1402,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(8);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1493,7 +1494,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSConv Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batchSize, C, 12, 16);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1501,7 +1502,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(M, C, K, K);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1509,7 +1510,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(M);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1607,7 +1608,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNConv Test - group";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batchSize, C, 12, 16);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1615,7 +1616,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(M, C / group, K, K);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1623,7 +1624,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(M);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1726,7 +1727,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNMul Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X0_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X0_cpu"), CPU);
t->Resize(1, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1734,7 +1735,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("X1_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X1_cpu"), CPU);
t->Resize(72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1791,7 +1792,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNSub Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X0_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X0_cpu"), CPU);
t->Resize(1, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1799,7 +1800,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("X1_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X1_cpu"), CPU);
t->Resize(72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1856,7 +1857,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSAdd Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X0_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X0_cpu"), CPU);
t->Resize(1, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1864,7 +1865,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("X1_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X1_cpu"), CPU);
t->Resize(1, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1921,7 +1922,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSAdd Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X0_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X0_cpu"), CPU);
t->Resize(1, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -1929,7 +1930,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("X1_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X1_cpu"), CPU);
t->Resize(1, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -2011,7 +2012,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNNeuron Test: " << n;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(1, 4, 12, 12);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -2065,7 +2066,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNDropout Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(1, 12, 57, 72);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -2136,7 +2137,7 @@ void testMPSCNN() {
<< " - scale: " << scale;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(1, channels, 40, 40);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -2144,7 +2145,7 @@ void testMPSCNN() {
}
{
// Use the batch-first encoding (n, [bbox])
auto* t = ws.CreateBlob("R")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("R"), CPU);
t->Resize(6, 5);
for (auto i = 0; i < t->dim32(0); ++i) {
t->mutable_data<float>()[5 * i + 0] = 0; // batch
@ -2250,14 +2251,14 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNRoIWarp Test 2";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(1, 8, 40, 40);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
t->size(), 4, 2, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws.CreateBlob("R")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("R"), CPU);
t->Resize(6, 4);
for (auto i = 0; i < t->dim32(0); ++i) {
t->mutable_data<float>()[4 * i + 0] = (i % 4 + 1) * 1.0 / scale;
@ -2362,7 +2363,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNResizeNearestOp Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, 37, 89);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -2497,7 +2498,7 @@ void testMPSCNN() {
vector<float> im_info{60, 80, 0.166667};
vector<float> anchors{-38, -16, 53, 31, -120, -120, 135, 135};
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(num_images, A, H, W);
for (auto i = 0; i < t->size(); ++i) {
t->mutable_data<float>()[i] = scores[i];
@ -2505,7 +2506,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("bbox_delta_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("bbox_delta_cpu"), CPU);
t->Resize(num_images, 4 * A, H, W);
for (auto i = 0; i < t->size(); ++i) {
t->mutable_data<float>()[i] = bbx[i];
@ -2513,7 +2514,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("im_info")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("im_info"), CPU);
t->Resize(num_images, 3);
for (auto i = 0; i < t->size(); ++i) {
t->mutable_data<float>()[i] = im_info[i];
@ -2521,7 +2522,7 @@ void testMPSCNN() {
}
{
auto* t = ws.CreateBlob("anchors")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("anchors"), CPU);
t->Resize(A, 4);
for (auto i = 0; i < t->size(); ++i) {
t->mutable_data<float>()[i] = anchors[i];
@ -2587,7 +2588,7 @@ void testMPSCNN() {
LOG(INFO) << "MPSCNNSoftmax Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
// Only works for spatial dimension of (1, 1) - weird.
t->Resize(batchSize, 12, 1, 1);
CPUContext ctx;
@ -2661,8 +2662,8 @@ void testMPSCNN() {
LOG(INFO) << "MPSConvTranspose Test";
Workspace ws;
{
auto* t =
ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(
ws.CreateBlob("X_cpu"), CPU);
t->Resize(batchSize, inputChannels, 8, 12);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -2675,7 +2676,7 @@ void testMPSCNN() {
{
auto* t =
ws.CreateBlob("W")->GetMutableTensor(CPU);
BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(
inputChannels,
outputChannels,
@ -2692,7 +2693,7 @@ void testMPSCNN() {
{
auto* t =
ws.CreateBlob("b")->GetMutableTensor(CPU);
BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(outputChannels);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -2809,7 +2810,7 @@ void testMPSCNN() {
<< batchSize;
Workspace ws;
for (auto i = 0; i < numInputs; ++i) {
auto* t = ws.CreateBlob(cpu(i))->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob(cpu(i)), CPU);
t->Resize(batchSize, array ? (i + 1) * 4 : 4, 10, 10);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -2891,7 +2892,7 @@ void testMPSCNN() {
}
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(batchSize, inputChannels, 53, 47);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -2964,7 +2965,7 @@ void testMPSCNN() {
<< numInputs << ", " << batchSize;
Workspace ws;
for (auto i = 0; i < numInputs; ++i) {
auto* t = ws.CreateBlob(cpu(i))->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob(cpu(i)), CPU);
t->Resize(batchSize, channelCount, 9, 17);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -3336,8 +3337,8 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
Workspace cws;
cws.RunNetOnce(initNet);
{
auto* t =
cws.CreateBlob(predictNet.external_input(0))->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(
cws.CreateBlob(predictNet.external_input(0)), CPU);
t->Resize(1, 224, 224, 4);
for (auto i = 0; i < t->size(); ++i) {
t->mutable_data<uint8_t>()[i] = i % 225;
@ -3348,8 +3349,8 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
Workspace mws;
mws.RunNetOnce(initNet);
{
auto* t =
mws.CreateBlob(predictNet.external_input(0))->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(
mws.CreateBlob(predictNet.external_input(0)), CPU);
t->Resize(1, 224, 224, 4);
for (auto i = 0; i < t->size(); ++i) {
t->mutable_data<uint8_t>()[i] = i % 225;
@ -3397,16 +3398,16 @@ void verifyRewrite(
dumpDef(predictNet);
dumpDef(metalPredictNet);
#define RUN_NET(ws, predictNet) \
ws.RunNetOnce(initNet); \
{ \
auto* t = \
ws.CreateBlob(predictNet.external_input(0))->GetMutableTensor(CPU); \
t->Resize(inputDims); \
CPUContext ctx; \
math::RandGaussian<float, CPUContext>( \
t->size(), 0, 1, t->mutable_data<float>(), &ctx); \
} \
#define RUN_NET(ws, predictNet) \
ws.RunNetOnce(initNet); \
{ \
auto* t = BlobGetMutableTensor( \
ws.CreateBlob(predictNet.external_input(0)), CPU); \
t->Resize(inputDims); \
CPUContext ctx; \
math::RandGaussian<float, CPUContext>( \
t->size(), 0, 1, t->mutable_data<float>(), &ctx); \
} \
ws.RunNetOnce(predictNet);
// initialize

View File

@ -16,7 +16,7 @@ void AddNoiseInput(const vector<int64_t>& shape, const string& name, Workspace*
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::RandGaussian<float, CPUContext>(

View File

@ -16,7 +16,7 @@ void AddNoiseInput(const vector<int64_t>& shape, const string& name, Workspace*
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::RandGaussian<float, CPUContext>(

View File

@ -679,7 +679,7 @@ void NNApi::init(const TensorVector& inputs, TensorVector* outputs) {
output_dims.push_back(dim);
}
auto* tensor = ws_.CreateBlob(blob)->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(ws_.CreateBlob(blob), CPU);
tensor->Resize(output_dims);
outputs->push_back(tensor);

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/core/init.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
@ -43,14 +43,14 @@ static double benchmark_conv_caffe2(
ws = &localWs;
}
{
auto* t = ws->CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws->CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws->CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws->CreateBlob("W"), CPU);
if (group == 1) {
t->Resize(K, C, kernel, kernel);
} else {
@ -61,7 +61,7 @@ static double benchmark_conv_caffe2(
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws->CreateBlob("B")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws->CreateBlob("B"), CPU);
t->Resize(K);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -129,14 +129,14 @@ static double benchmark_conv_nnapi(
ws = &localWs;
}
{
auto* t = ws->CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws->CreateBlob("X_cpu"), CPU);
t->Resize(N, H, W, C);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws->CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws->CreateBlob("W"), CPU);
if (group > 1) {
CAFFE_ENFORCE_EQ(C, group);
t->Resize(1, kernel, kernel, C);
@ -148,7 +148,7 @@ static double benchmark_conv_nnapi(
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws->CreateBlob("B")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws->CreateBlob("B"), CPU);
t->Resize(K);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -190,7 +190,7 @@ static double benchmark_conv_nnapi(
NetDef initNet;
NNApi model(initNet, netdef, ws);
std::vector<TensorCPU*> inputs, outputs;
inputs.push_back(ws->GetBlob("X_cpu")->GetMutableTensor(CPU));
inputs.push_back(BlobGetMutableTensor(ws->GetBlob("X_cpu"), CPU));
CAFFE_ENFORCE(model.run(inputs, &outputs));
for (int i = 0; i < warmup; i++) {
@ -220,14 +220,14 @@ static double benchmark_conv_nnapi_int8(
ws = &localWs;
}
{
auto* t = ws->CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws->CreateBlob("X_cpu"), CPU);
t->Resize(N, H, W, C);
for (int i = 0; i < t->size(); i++) {
t->mutable_data<uint8_t>()[i] = rand() % 10;
}
}
{
auto* t = ws->CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws->CreateBlob("W"), CPU);
if (group > 1) {
CAFFE_ENFORCE_EQ(C, group);
t->Resize(1, kernel, kernel, C);
@ -243,7 +243,7 @@ static double benchmark_conv_nnapi_int8(
// should be of ANEURALNETWORKS_TENSOR_INT32, with zeroPoint of 0 and
// bias_scale == input_scale * filter_scale.
{
auto* t = ws->CreateBlob("B")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws->CreateBlob("B"), CPU);
t->Resize(K);
for (int i = 0; i < t->size(); i++) {
t->mutable_data<int32_t>()[i] = rand() % 10;
@ -322,7 +322,7 @@ static double benchmark_conv_nnapi_int8(
NetDef initNet;
NNApi model(initNet, netdef, ws);
std::vector<TensorCPU*> inputs, outputs;
inputs.push_back(ws->GetBlob("X_cpu")->GetMutableTensor(CPU));
inputs.push_back(BlobGetMutableTensor(ws->GetBlob("X_cpu"), CPU));
CAFFE_ENFORCE(model.run(inputs, &outputs));
for (int i = 0; i < warmup; i++) {

View File

@ -55,7 +55,7 @@ static void test_relu(int N, int C, int H, int W) {
// CPU reference
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, H, W, C);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -81,7 +81,7 @@ static void test_relu(int N, int C, int H, int W) {
NetDef initNet;
NNApi model(initNet, netdef, &ws);
std::vector<TensorCPU*> inputs, outputs;
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
EXPECT_TRUE(model.run(inputs, &outputs));
const auto& t_nn = *outputs[0];
@ -103,21 +103,21 @@ static void test_conv_NHWC(
int stride_w) {
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, H, W, C);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(K, kernel, kernel, C);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws.CreateBlob("B")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("B"), CPU);
t->Resize(K);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -189,7 +189,7 @@ static void test_conv_NHWC(
NetDef initNet;
NNApi model(initNet, netdef, &ws);
std::vector<TensorCPU*> inputs, outputs;
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
EXPECT_TRUE(model.run(inputs, &outputs));
const auto& t_nn = *outputs[0];
@ -211,21 +211,21 @@ static void test_depthwise_conv_NHWC(
int stride_w) {
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, H, W, C);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(1, kernel, kernel, D);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws.CreateBlob("B")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("B"), CPU);
t->Resize(D);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -406,7 +406,7 @@ static void test_depthwise_conv_NHWC(
NetDef initNet;
NNApi model(initNet, netdef, &ws);
std::vector<TensorCPU*> inputs, outputs;
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
EXPECT_TRUE(model.run(inputs, &outputs));
const auto& t_nn = *outputs[0];
@ -428,7 +428,7 @@ static void test_pooling(
int stride_w) {
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, H, W, C);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(
@ -496,7 +496,7 @@ static void test_pooling(
NetDef initNet;
NNApi model(initNet, netdef, &ws);
std::vector<TensorCPU*> inputs, outputs;
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
EXPECT_TRUE(model.run(inputs, &outputs));
const auto& t_nn = *outputs[0];
@ -506,7 +506,7 @@ static void test_pooling(
static void test_softmax(int N, int C, int H = 1, int W = 1) {
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
if (H == 1 && W == 1) {
t->Resize(N, C);
} else {
@ -538,7 +538,7 @@ static void test_softmax(int N, int C, int H = 1, int W = 1) {
NetDef initNet;
NNApi model(initNet, netdef, &ws);
std::vector<TensorCPU*> inputs, outputs;
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
EXPECT_TRUE(model.run(inputs, &outputs));
const auto& t_nn = *outputs[0];

View File

@ -178,7 +178,7 @@ void testOpenGLCopyOps(int N, int C, int H, int W, float error, int tile_x = 1,
LOG(INFO) << "OPENGLCopyFrom/To Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
@ -275,7 +275,7 @@ void testOpenGLConv(int N,
<< " Op: " << glPoolOperationName[poolOp];
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
if (random_input) {
@ -301,7 +301,7 @@ void testOpenGLConv(int N,
}
if (poolOp != AveragePool && poolOp != MaxPool) {
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
if (poolOp != ConvTranspose && poolOp != ConvTransposePRelu && poolOp != ConvTransposeRelu) {
t->Resize(K, C, kernel_h, kernel_w);
} else {
@ -343,7 +343,7 @@ void testOpenGLConv(int N,
// bias
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(K);
CPUContext ctx;
if (random_input) {
@ -367,7 +367,7 @@ void testOpenGLConv(int N,
}
if (poolOp == ConvPRelu || poolOp == ConvTransposePRelu) {
auto* t = ws.CreateBlob("p")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("p"), CPU);
t->Resize(K);
CPUContext ctx;
if (random_input) {
@ -532,7 +532,7 @@ void testOpenGLPRelu(
<< "C: " << C << ", H: " << H << ", W: " << W;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
// Too noisy.
@ -541,7 +541,7 @@ void testOpenGLPRelu(
// prelu scale
{
auto* t = ws.CreateBlob("p")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("p"), CPU);
t->Resize(prelu_size);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
@ -603,7 +603,7 @@ void testOpenGLRelu(int N, int C, int H, int W, int input_tile_x, int input_tile
<< "C: " << C << ", H: " << H << ", W: " << W;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
// Too noisy.
@ -664,13 +664,13 @@ void testOpenGLAdd(int N, int C, int H, int W, float error = 0.1, int input_tile
<< "C: " << C << ", H: " << H << ", W: " << W;
Workspace ws;
{
auto* t0 = ws.CreateBlob("X_cpu0")->GetMutableTensor(CPU);
auto* t0 = BlobGetMutableTensor(ws.CreateBlob("X_cpu0"), CPU);
t0->Resize(N, C, H, W);
CPUContext ctx0;
// Too noisy.
math::RandGaussian<float, CPUContext>(t0->size(), 0, 30, t0->mutable_data<float>(), &ctx0);
auto* t1 = ws.CreateBlob("X_cpu1")->GetMutableTensor(CPU);
auto* t1 = BlobGetMutableTensor(ws.CreateBlob("X_cpu1"), CPU);
t1->Resize(N, C, H, W);
CPUContext ctx1;
// Too noisy.
@ -750,13 +750,13 @@ void testOpenGLSub(int N, int C, int H, int W, float error = 0.1) {
Workspace ws;
{
auto* t0 = ws.CreateBlob("X_cpu0")->GetMutableTensor(CPU);
auto* t0 = BlobGetMutableTensor(ws.CreateBlob("X_cpu0"), CPU);
t0->Resize(N, C, H, W);
CPUContext ctx0;
// Too noisy.
math::RandGaussian<float, CPUContext>(t0->size(), 0, 30, t0->mutable_data<float>(), &ctx0);
auto* t1 = ws.CreateBlob("X_cpu1")->GetMutableTensor(CPU);
auto* t1 = BlobGetMutableTensor(ws.CreateBlob("X_cpu1"), CPU);
t1->Resize(N, C, H, W);
CPUContext ctx1;
// Too noisy.
@ -814,8 +814,8 @@ void testOpenGLConcat(int N, std::vector<int> Cs, int H, int W, bool tiling = fa
<< "H: " << H << ", W: " << W;
Workspace ws;
for (int i = 0; i < Cs.size(); i++) {
auto* t =
ws.CreateBlob("X_cpu" + caffe2::to_string(i))->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(
ws.CreateBlob("X_cpu" + caffe2::to_string(i)), CPU);
t->Resize(N, Cs[i], H, W);
CPUContext ctx0;
// Too noisy.
@ -891,7 +891,7 @@ void testOpenGLSigmoid(int N, int C, int H, int W, float error) {
<< "C: " << C << ", H: " << H << ", W: " << W;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
// Too noisy.
@ -942,7 +942,7 @@ void testOpenGLTanh(int N, int C, int H, int W, float error) {
<< "C: " << C << ", H: " << H << ", W: " << W;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(t->size(), 0, 2, t->mutable_data<float>(), &ctx);
@ -992,14 +992,14 @@ void testOpenGLMul(int N, int C, int H, int W, float error) {
<< "C: " << C << ", H: " << H << ", W: " << W;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(t->size(), -10, 10, t->mutable_data<float>(), &ctx);
}
{
auto* t = ws.CreateBlob("B")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("B"), CPU);
t->Resize(1);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(t->size(), -10, 10, t->mutable_data<float>(), &ctx);
@ -1060,7 +1060,7 @@ void testOpenGLSoftmax(int N, int D, float error, bool tiled = false) {
LOG(INFO) << "OpenGL Softmax Test "
<< "N: " << N << " D: " << D << " Tiled:" << tiled;
Workspace ws;
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
{
t->Resize(N, D);
CPUContext ctx;
@ -1151,7 +1151,7 @@ void testOpenGLInstanceNorm(int N, int C, int H, int W, float error) {
<< "C: " << C << ", H: " << H << ", W: " << W;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
// Too noisy.
@ -1163,7 +1163,7 @@ void testOpenGLInstanceNorm(int N, int C, int H, int W, float error) {
// scale
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(C);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -1172,7 +1172,7 @@ void testOpenGLInstanceNorm(int N, int C, int H, int W, float error) {
}
// bias
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(C);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -1254,7 +1254,7 @@ void testOpenGLInstanceNormPRelu(int N, int C, int H, int W, float error) {
<< "C: " << C << ", H: " << H << ", W: " << W;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
// Too noisy.
@ -1266,7 +1266,7 @@ void testOpenGLInstanceNormPRelu(int N, int C, int H, int W, float error) {
// scale
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(C);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -1275,7 +1275,7 @@ void testOpenGLInstanceNormPRelu(int N, int C, int H, int W, float error) {
}
// bias
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(C);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -1284,7 +1284,7 @@ void testOpenGLInstanceNormPRelu(int N, int C, int H, int W, float error) {
}
// prelu scale
{
auto* t = ws.CreateBlob("p")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("p"), CPU);
t->Resize(C);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
@ -1385,7 +1385,7 @@ void OpenGL_speedtest(int N,
<< " C: " << C << " H: " << H << " W: " << W;
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
if (random_input) {
@ -1399,7 +1399,7 @@ void OpenGL_speedtest(int N,
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(K, C, kernel_h, kernel_w);
CPUContext ctx;
if (random_input) {
@ -1413,7 +1413,7 @@ void OpenGL_speedtest(int N,
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(K);
CPUContext ctx;
if (random_input) {
@ -1479,7 +1479,7 @@ void testOpenGLPadImage(
{
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
@ -1593,7 +1593,7 @@ void testOpenGLResize(int N,
{
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
@ -1675,7 +1675,7 @@ void testOpenGLPreprocess(int N, int C, int H, int W, float error) {
LOG(INFO) << "OpenGL Preprocess Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, H, W, C);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -1684,7 +1684,7 @@ void testOpenGLPreprocess(int N, int C, int H, int W, float error) {
}
{
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
t->Resize(3);
CPUContext ctx;
t->mutable_data<float>()[0] = 100;
@ -1748,7 +1748,7 @@ void testOpenGLDeprocess(int N, int C, int H, int W, float error) {
LOG(INFO) << "OpenGLDeprocess Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -1757,7 +1757,7 @@ void testOpenGLDeprocess(int N, int C, int H, int W, float error) {
}
{
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
t->Resize(3);
CPUContext ctx;
t->mutable_data<float>()[0] = 30;
@ -1800,7 +1800,7 @@ void testOpenGLNormPlanarYUV(int N, int C, int H, int W, float error) {
LOG(INFO) << "OpenGLNormPlanarYUV Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, 3, H, W);
CPUContext ctx;
for (auto i = 0; i < t->size(); ++i) {
@ -1809,7 +1809,7 @@ void testOpenGLNormPlanarYUV(int N, int C, int H, int W, float error) {
}
{
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
t->Resize(1, 3);
CPUContext ctx;
t->mutable_data<float>()[0] = 30;
@ -1818,7 +1818,7 @@ void testOpenGLNormPlanarYUV(int N, int C, int H, int W, float error) {
}
{
auto* t = ws.CreateBlob("stdev")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("stdev"), CPU);
t->Resize(1, 3);
CPUContext ctx;
t->mutable_data<float>()[0] = 6;
@ -1879,7 +1879,7 @@ void OpenGL_copyops_speedtest(int N,
LOG(INFO) << "OpenGL CopyOps Speed Test";
Workspace ws;
{
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
t->Resize(N, C, H, W);
CPUContext ctx;
if (random_input) {
@ -1893,7 +1893,7 @@ void OpenGL_copyops_speedtest(int N,
}
{
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
t->Resize(K, C, kernel_h, kernel_w);
CPUContext ctx;
if (random_input) {
@ -1907,7 +1907,7 @@ void OpenGL_copyops_speedtest(int N,
}
{
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
t->Resize(K);
CPUContext ctx;
if (random_input) {
@ -1990,8 +1990,8 @@ void compareModelsForOpenGL(std::string name,
Workspace cws;
cws.RunNetOnce(initNet);
auto* t_cpu = cws.CreateBlob(truncatedPredictNet.external_input(0))
->GetMutableTensor(CPU);
auto* t_cpu = BlobGetMutableTensor(
cws.CreateBlob(truncatedPredictNet.external_input(0)), CPU);
if (name == "styleTransfer") {
CAFFE_ENFORCE_EQ(input_order, "NHWC");
CAFFE_ENFORCE_EQ(input_type, "uint8_t");
@ -2032,8 +2032,8 @@ void compareModelsForOpenGL(std::string name,
Workspace mws;
mws.RunNetOnce(initNet);
auto* t_gl = mws.CreateBlob(truncatedOpenGLPredictNet.external_input(0))
->GetMutableTensor(CPU);
auto* t_gl = BlobGetMutableTensor(
mws.CreateBlob(truncatedOpenGLPredictNet.external_input(0)), CPU);
if (name == "styleTransfer") {
CAFFE_ENFORCE_EQ(input_order, "NHWC");
CAFFE_ENFORCE_EQ(input_type, "uint8_t");
@ -2116,7 +2116,7 @@ void compareBatchedToTiledModels(std::string name,
tws.RunNetOnce(initNet);
auto* t_batch =
tws.CreateBlob(bachedNet.external_input(0))->GetMutableTensor(CPU);
BlobGetMutableTensor(tws.CreateBlob(bachedNet.external_input(0)), CPU);
if (name == "styleTransfer") {
CAFFE_ENFORCE_EQ(input_order, "NHWC");
CAFFE_ENFORCE_EQ(input_type, "uint8_t");
@ -2143,7 +2143,7 @@ void compareBatchedToTiledModels(std::string name,
bws.RunNetOnce(initNet);
auto* t_tiling =
bws.CreateBlob(tiledNet.external_input(0))->GetMutableTensor(CPU);
BlobGetMutableTensor(bws.CreateBlob(tiledNet.external_input(0)), CPU);
if (name == "styleTransfer") {
CAFFE_ENFORCE_EQ(input_order, "NHWC");
CAFFE_ENFORCE_EQ(input_type, "uint8_t");

View File

@ -14,7 +14,7 @@
#define POPULATE_DATA(_n, _s, _l) \
do { \
Blob* _blob = ws.CreateBlob((_n)); \
auto* _tensor = _blob->GetMutableTensor(CPU); \
auto* _tensor = BlobGetMutableTensor(_blob, CPU); \
_tensor->Resize((_s)); \
memcpy(_tensor->mutable_data<float>(), data_##_l, _tensor->nbytes()); \
} while (0)
@ -23,7 +23,7 @@
#define POPULATE_DATA(_n, _s, _l) \
do { \
Blob* _blob = ws.CreateBlob((_n)); \
auto* _tensor = _blob->GetMutableTensor(CPU); \
auto* _tensor = BlobGetMutableTensor(_blob, CPU); \
_tensor->Resize((_s)); \
memset(_tensor->mutable_data<float>(), 1, _tensor->nbytes()); \
} while (0)
@ -43,7 +43,7 @@ void AddConstInput(const vector<int64_t>& shape,
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::Set<float, CPUContext>(tensor->size(), value,
tensor->mutable_data<float>(),
@ -56,7 +56,7 @@ void AddNoiseInput(const vector<int64_t>& shape,
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::RandGaussian<float, CPUContext>(

View File

@ -289,13 +289,13 @@ void ConvTest2b1b(int IC, int KH, int KW, int H, int W, int OC, int N, ConvArgs
def.add_arg()->CopyFrom(MakeArgument("pad_r", args.pad_r));
def.add_arg()->CopyFrom(MakeArgument("pad_t", args.pad_t));
def.add_arg()->CopyFrom(MakeArgument("pad_b", args.pad_b));
auto* Xws = ws.CreateBlob("X")->GetMutableTensor(CPU);
auto* Xws = BlobGetMutableTensor(ws.CreateBlob("X"), CPU);
Xws->ResizeLike(X);
Xws->ShareExternalPointer(X.mutable_data<float>(), X.size());
auto* Wws = ws.CreateBlob("W")->GetMutableTensor(CPU);
auto* Wws = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
Wws->ResizeLike(W_);
Wws->ShareExternalPointer(W_.mutable_data<float>(), W_.size());
auto* bws = ws.CreateBlob("b")->GetMutableTensor(CPU);
auto* bws = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
bws->ResizeLike(bias);
bws->ShareExternalPointer(bias.mutable_data<float>(), bias.size());
ws.RunOperatorOnce(def);

View File

@ -30,7 +30,7 @@ class BatchMatMulOpGPUTest : public testing::Test {
const float value,
const string& name) {
Blob* blob = ws_.CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CUDA);
auto* tensor = BlobGetMutableTensor(blob, CUDA);
tensor->Resize(dims);
math::Set<float, CUDAContext>(
tensor->size(),

View File

@ -24,7 +24,7 @@ class BatchMatMulOpTest : public testing::Test {
const float value,
const string& name) {
Blob* blob = ws_.CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(dims);
math::Set<float, CPUContext>(
tensor->size(),

View File

@ -16,7 +16,7 @@ static void AddScalarInput(
Workspace* ws,
bool isEmpty = false) {
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
if (!isEmpty) {
tensor->Resize(vector<int64_t>{1});
*(tensor->template mutable_data<DataT>()) = value;

View File

@ -27,8 +27,8 @@ void runWithSharedBuffer<CPUContext>(
auto* mutexPtr = mutexBlob->GetMutable<std::unique_ptr<std::mutex>>();
std::lock_guard<std::mutex> g(**mutexPtr);
auto* buffer =
ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CPU__")->GetMutableTensor(CPU);
auto* buffer = BlobGetMutableTensor(
ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CPU__"), CPU);
f(buffer);
}
}

View File

@ -20,8 +20,8 @@ void runWithSharedBuffer<CUDAContext>(
auto* mutexPtr = mutexBlob->GetMutable<std::unique_ptr<std::mutex>>();
std::lock_guard<std::mutex> g(**mutexPtr);
auto* buffer =
ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA__")->GetMutableTensor(CUDA);
auto* buffer = BlobGetMutableTensor(
ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA__"), CUDA);
f(buffer);
}
}

View File

@ -17,7 +17,7 @@ void AddConstInput(const vector<int64_t>& shape,
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::Set<float, CPUContext>(
tensor->size(), value, tensor->template mutable_data<float>(), &context);
@ -29,7 +29,7 @@ void AddNoiseInput(const vector<int64_t>& shape,
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::RandGaussian<float, CPUContext>(

View File

@ -1428,7 +1428,7 @@ class TreeCursorSerializer : public BlobSerializerBase {
// serialize offsets as a tensor
if (cursor->offsets.size() > 0) {
Blob offsets_blob;
auto* offsets = offsets_blob.GetMutableTensor(CPU);
auto* offsets = BlobGetMutableTensor(&offsets_blob, CPU);
offsets->Resize(cursor->offsets.size());
std::copy(
cursor->offsets.begin(),

View File

@ -150,7 +150,7 @@ bool CuDNNDropoutOp::DoRunWithType() {
// Reshape tensor descriptors if necessary
if (X.dims() != cudnn_input_dims_ && !is_test_) {
CAFFE_ENFORCE(scratch_blob_);
Tensor* states = scratch_blob_->GetMutableTensor(CUDA);
Tensor* states = BlobGetMutableTensor(scratch_blob_, CUDA);
cudnn_input_dims_ = X.dims();
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
data_desc_,

View File

@ -19,7 +19,7 @@ void FillTensor(
const std::vector<int64_t>& shape,
const std::vector<I_Type>& values) {
auto* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(Context::GetDeviceType());
auto* tensor = BlobGetMutableTensor(blob, Context::GetDeviceType());
tensor->Resize(shape);
auto* mutable_data = tensor->template mutable_data<O_Type>();
const O_Type* data = reinterpret_cast<const O_Type*>(values.data());

View File

@ -18,7 +18,7 @@ static void AddConstInput(
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::Set<float, CPUContext>(
tensor->size(), value, tensor->template mutable_data<float>(), &context);
@ -34,7 +34,7 @@ static void AddLinSpacedInput(
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
EigenVectorMap<float> tensor_vec(
tensor->template mutable_data<float>(), tensor->size());
@ -51,7 +51,7 @@ static void AddInput(
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
EigenVectorMap<float> tensor_vec(
tensor->template mutable_data<float>(), tensor->size());

View File

@ -353,7 +353,7 @@ class IndexSerializer : public BlobSerializerBase {
SerializationAcceptor acceptor) override {
auto& base = blob.template Get<std::unique_ptr<IndexBase>>();
Blob tensor_blob;
auto* tensor_out = tensor_blob.GetMutableTensor(CPU);
auto* tensor_out = BlobGetMutableTensor(&tensor_blob, CPU);
if (base->Type().Match<std::string>()) {
doStore<std::string>(base, tensor_out);

View File

@ -213,23 +213,23 @@ class ONNXWhileOp final : public Operator<Context> {
lcd_tensors_.clear();
for (int i = 2; i < body_net_def.external_input_size(); ++i) {
Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i));
Tensor* t = b->GetMutableTensor(Context::GetDeviceType());
Tensor* t = BlobGetMutableTensor(b, Context::GetDeviceType());
lcd_tensors_.push_back(t);
}
// First output is the iteration variable
auto* iteration_var_blob = loop_ws_->CreateBlob(
body_net_def.external_input(0));
iteration_var_ =
iteration_var_blob->GetMutableTensor(Context::GetDeviceType());
BlobGetMutableTensor(iteration_var_blob, Context::GetDeviceType());
input_condition_var_ =
loop_ws_->CreateBlob(body_net_def.external_input(1))
->GetMutableTensor(Context::GetDeviceType());
input_condition_var_ = BlobGetMutableTensor(
loop_ws_->CreateBlob(body_net_def.external_input(1)),
Context::GetDeviceType());
auto* condition_var_blob =
loop_ws_->CreateBlob(body_net_def.external_output(0));
condition_var_ =
condition_var_blob->GetMutableTensor(Context::GetDeviceType());
BlobGetMutableTensor(condition_var_blob, Context::GetDeviceType());
condition_var_->Resize(1);
condition_var_->template mutable_data<bool>();

View File

@ -15,7 +15,7 @@ void BlobToTensorDescriptor(
// Memory type
// We only allow weights to be CPU tensor for now
CAFFE_ENFORCE(
blob->IsTensorType(CPU),
BlobIsTensorType(*blob, CPU),
"Initialization blob ",
name,
" needs to be TensorCPU");

View File

@ -65,8 +65,8 @@ class GPUFallbackOpEx final : public Operator<CUDAContext> {
bool need_sync = false;
for (int i = 0; i < InputSize(); ++i) {
if (this->InputIsTensorType(i, CUDA)) {
local_input_blobs_[i]->GetMutableTensor(CPU)->CopyFrom(
Input(i), &context_);
BlobGetMutableTensor(local_input_blobs_[i], CPU)
->CopyFrom(Input(i), &context_);
need_sync = true;
} else {
VLOG(1) << "Input " << i << " is not TensorCUDA. Skipping copy.";
@ -95,7 +95,7 @@ class GPUFallbackOpEx final : public Operator<CUDAContext> {
continue;
}
CAFFE_ENFORCE(
local_output_blobs_[i]->IsTensorType(CPU),
BlobIsTensorType(*local_output_blobs_[i], CPU),
"GPU fallback op currently does not support non-TensorCPU "
"output type who needs copying.");
Output(i)->CopyFrom(local_output_blobs_[i]->template Get<TensorCPU>());

View File

@ -40,7 +40,7 @@ TEST(OperatorFallbackTest, IncrementByOneOp) {
for (int i = 0; i < 6; ++i) {
source_tensor.mutable_data<float>()[i] = i;
}
ws.CreateBlob("X")->GetMutableTensor(CPU)->CopyFrom(source_tensor);
BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(source_tensor);
unique_ptr<OperatorBase> op(CreateOperator(op_def, &ws));
EXPECT_TRUE(op.get() != nullptr);
EXPECT_TRUE(op->Run());
@ -64,7 +64,7 @@ TEST(OperatorFallbackTest, GPUIncrementByOneOp) {
for (int i = 0; i < 6; ++i) {
source_tensor.mutable_data<float>()[i] = i;
}
ws.CreateBlob("X")->GetMutableTensor(CUDA)->CopyFrom(source_tensor);
BlobGetMutableTensor(ws.CreateBlob("X"), CUDA)->CopyFrom(source_tensor);
unique_ptr<OperatorBase> op(CreateOperator(op_def, &ws));
EXPECT_TRUE(op.get() != nullptr);
EXPECT_TRUE(op->Run());

View File

@ -20,7 +20,7 @@ static void AddConstInput(
option.set_device_type(PROTO_CUDA);
CUDAContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CUDA);
auto* tensor = BlobGetMutableTensor(blob, CUDA);
tensor->Resize(shape);
math::Set<float, CUDAContext>(
tensor->size(), value, tensor->template mutable_data<float>(), &context);

View File

@ -43,11 +43,10 @@ class RecurrentNetworkBlobFetcherOp final : public Operator<Context> {
prefix_ + std::string("_") + blob_name + caffe2::to_string(i);
blob_names_vector.push_back(newBlobName);
ws_->CreateBlob(newBlobName)
->GetMutableTensor(CPU)
BlobGetMutableTensor(ws_->CreateBlob(newBlobName), CPU)
->ResizeLike(currentTensor);
auto type = Context::GetDeviceType();
auto* newTensor = ws_->GetBlob(newBlobName)->GetMutableTensor(type);
auto* newTensor = BlobGetMutableTensor(ws_->GetBlob(newBlobName), type);
newTensor->CopyFrom(currentTensor);
}
}

View File

@ -111,10 +111,10 @@ class RecurrentNetworkExecutorBase {
// the forward-only mode.
std::string this_timestep_blob =
timestep_blob_ + "_rnnexec_t" + caffe2::to_string(t);
ws->CreateBlob(this_timestep_blob)->GetMutableTensor(CPU)->Resize(1);
BlobGetMutableTensor(ws->CreateBlob(this_timestep_blob), CPU)->Resize(1);
auto b = ws->GetBlob(this_timestep_blob);
CAFFE_ENFORCE(b);
b->GetMutableTensor(CPU)->template mutable_data<int32_t>()[0] = t;
BlobGetMutableTensor(b, CPU)->template mutable_data<int32_t>()[0] = t;
// Copy the operators from template
for (auto& template_rnn_op : timestep_ops_template_) {

View File

@ -52,10 +52,11 @@ struct CAFFE2_API ScratchWorkspaces {
};
// Stores the timestep `t` as element 0 of the int32 CPU tensor held by the
// workspace blob `blob_name`, after resizing that tensor to one element.
// CAFFE_ENFORCE checks that the blob can be fetched back from `ws` before the
// write.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were lost.
// Each `->GetMutableTensor(CPU)` statement appears to be the removed line and
// the `BlobGetMutableTensor(..., CPU)` statement right after it the added
// replacement -- read each pair as alternatives, not as consecutive calls.
inline void UpdateTimestepBlob(Workspace* ws, std::string blob_name, int t) {
// Size the timestep tensor to exactly one element (old form, then new form).
ws->CreateBlob(blob_name)->GetMutableTensor(CPU)->Resize(1);
BlobGetMutableTensor(ws->CreateBlob(blob_name), CPU)->Resize(1);
auto timestepBlob = ws->GetBlob(blob_name);
CAFFE_ENFORCE(timestepBlob);
// Write `t` into the tensor's only slot (old form, then new form).
timestepBlob->GetMutableTensor(CPU)->template mutable_data<int32_t>()[0] = t;
BlobGetMutableTensor(timestepBlob, CPU)->template mutable_data<int32_t>()[0] =
t;
}
CAFFE2_API std::map<string, string> GetRecurrentMapping(
@ -71,8 +72,9 @@ void applyOffsetAlias(
<< " at offset: " << oc.offset;
auto srcBlob = ws->GetBlob(oc.src);
CAFFE_ENFORCE(srcBlob);
auto* src = srcBlob->GetMutableTensor(Context::GetDeviceType());
auto* dst = ws->GetBlob(oc.dst)->GetMutableTensor(Context::GetDeviceType());
auto* src = BlobGetMutableTensor(srcBlob, Context::GetDeviceType());
auto* dst =
BlobGetMutableTensor(ws->GetBlob(oc.dst), Context::GetDeviceType());
auto timestep = src->size() / src->dim(0);
auto dims = src->dims();
const int32_t startDstTimestep =
@ -113,7 +115,7 @@ void initializeRecurrentInput(
Context* context) {
auto stateBlob = ws->GetBlob(rc.state);
CAFFE_ENFORCE(stateBlob);
auto* state = stateBlob->GetMutableTensor(Context::GetDeviceType());
auto* state = BlobGetMutableTensor(stateBlob, Context::GetDeviceType());
auto inputBlob = ws->GetBlob(rc.input);
CAFFE_ENFORCE(inputBlob);
@ -660,7 +662,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
auto gBlob = sharedWs_->GetBlob(param.grad);
CAFFE_ENFORCE(gBlob);
auto* g = gBlob->GetMutableTensor(Context::GetDeviceType());
auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
g->ResizeLike(p);
math::Set<T, Context>(
g->size(),
@ -676,7 +678,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
auto gBlob = sharedWs_->CreateBlob(rg.grad);
CAFFE_ENFORCE(gBlob);
auto* g = gBlob->GetMutableTensor(Context::GetDeviceType());
auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
g->ResizeLike(p);
CAFFE_ENFORCE_EQ(g->ndim(), 3);
const auto timestep = g->size() / g->dim(0);
@ -703,7 +705,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
<< ". Size: " << Input(gradientInputIndex).size();
auto pGradientBlob = sharedWs_->GetBlob(gradientName);
CAFFE_ENFORCE(pGradientBlob);
auto* g = pGradientBlob->GetMutableTensor(Context::GetDeviceType());
auto* g = BlobGetMutableTensor(pGradientBlob, Context::GetDeviceType());
g->ResizeLike(Input(gradientInputIndex));
g->template mutable_data<T>();
}
@ -717,7 +719,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
<< rg.lastExternalGrad << " for final time step (sep. blob)";
auto gBlob = sharedWs_->GetBlob(rg.grad);
CAFFE_ENFORCE(gBlob);
auto* g = gBlob->GetMutableTensor(Context::GetDeviceType());
auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
auto oglastBlob = sharedWs_->GetBlob(rg.lastExternalGrad);
CAFFE_ENFORCE(oglastBlob);
@ -779,7 +781,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
T* output_data = Output(outputIdx)->template mutable_data<T>();
auto pBlob = sharedWs_->GetBlob(recurrentGradients_[i].grad);
CAFFE_ENFORCE(pBlob);
auto* p = pBlob->GetMutableTensor(Context::GetDeviceType());
auto* p = BlobGetMutableTensor(pBlob, Context::GetDeviceType());
if (Input(inputId).ndim() >= 2) {
// Gradient states blob should live. And if it gets changed by the

View File

@ -18,7 +18,7 @@ void AddConstInput(
Context* context,
Workspace* ws) {
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(Context::GetDeviceType());
auto* tensor = BlobGetMutableTensor(blob, Context::GetDeviceType());
tensor->Resize(shape);
math::Set<float, Context>(
tensor->size(), value, tensor->template mutable_data<float>(), context);
@ -39,7 +39,7 @@ void AddInput<CPUContext>(
const string& name,
Workspace* ws) {
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
EigenVectorMap<float> tensor_vec(
tensor->template mutable_data<float>(), tensor->size());
@ -57,7 +57,7 @@ void AddInput<CUDAContext>(
tmp_vec.array() = utils::AsEArrXt(values);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CUDA);
auto* tensor = BlobGetMutableTensor(blob, CUDA);
tensor->CopyFrom(tmp);
}

View File

@ -9,7 +9,7 @@ class StringJoinOpTest : public testing::Test {
public:
bool runOp(const TensorCPU& input) {
auto* blob = ws_.CreateBlob("X");
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->ResizeLike(input);
tensor->ShareData(input);
@ -26,7 +26,7 @@ class StringJoinOpTest : public testing::Test {
const std::string* checkAndGetOutput(int outputSize) {
const auto* output = ws_.GetBlob("Y");
EXPECT_NE(output, nullptr);
EXPECT_TRUE(output->IsTensorType(CPU));
EXPECT_TRUE(BlobIsTensorType(*output, CPU));
const auto& outputTensor = output->Get<TensorCPU>();
EXPECT_EQ(outputTensor.ndim(), 1);
EXPECT_EQ(outputTensor.dim(0), outputSize);
@ -42,7 +42,7 @@ TEST_F(StringJoinOpTest, testString1DJoin) {
std::vector<std::string> input = {"a", "xx", "c"};
auto blob = caffe2::make_unique<Blob>();
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
tensor->Resize(input.size());
auto* data = tensor->template mutable_data<std::string>();
for (int i = 0; i < input.size(); ++i) {
@ -62,7 +62,7 @@ TEST_F(StringJoinOpTest, testString2DJoin) {
{"dd", "ee", "ff"}};
auto blob = caffe2::make_unique<Blob>();
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
tensor->Resize(input.size(), input[0].size());
auto* data = tensor->template mutable_data<std::string>();
for (int i = 0; i < input.size(); ++i) {
@ -82,7 +82,7 @@ TEST_F(StringJoinOpTest, testFloat1DJoin) {
std::vector<float> input = {3.90f, 5.234f, 8.12f};
auto blob = caffe2::make_unique<Blob>();
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
tensor->Resize(input.size());
auto* data = tensor->template mutable_data<float>();
for (int i = 0; i < input.size(); ++i) {
@ -102,7 +102,7 @@ TEST_F(StringJoinOpTest, testFloat2DJoin) {
{4.67f, 5.90f, 6.32f}};
auto blob = caffe2::make_unique<Blob>();
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
tensor->Resize(input.size(), input[0].size());
auto* data = tensor->template mutable_data<float>();
for (int i = 0; i < input.size(); ++i) {
@ -122,7 +122,7 @@ TEST_F(StringJoinOpTest, testLong2DJoin) {
std::vector<std::vector<int64_t>> input = {{100, 200}, {1000, 2000}};
auto blob = caffe2::make_unique<Blob>();
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
tensor->Resize(input.size(), input[0].size());
auto* data = tensor->template mutable_data<int64_t>();
for (int i = 0; i < input.size(); ++i) {

View File

@ -82,10 +82,10 @@ class PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp
auto defaultNoiseSize = OperatorBase::GetSingleArgument<int>(
"noise_size", 491 /* prime to avoid artifacts */);
if (!noiseBlob->IsTensorType(CPU)) {
if (!BlobIsTensorType(*noiseBlob, CPU)) {
// Initialize random noise on first use.
// Cache it to maintain temporal consistency.
auto* t = noiseBlob->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(noiseBlob, CPU);
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
// Noise space is larger for vectorized code due to the

View File

@ -56,7 +56,7 @@ bool TensorProtosDBInput<Context>::Prefetch() {
protos.mutable_protos(i)->clear_device_detail();
}
deserializer.Deserialize(
protos.protos(i), prefetched_blobs_[i].GetMutableTensor(CPU));
protos.protos(i), BlobGetMutableTensor(&prefetched_blobs_[i], CPU));
}
} else {
vector<Tensor> temp_tensors;
@ -74,11 +74,11 @@ bool TensorProtosDBInput<Context>::Prefetch() {
vector<int> dims(
protos.protos(i).dims().begin(), protos.protos(i).dims().end());
dims.insert(dims.begin(), batch_size_);
prefetched_blobs_[i].GetMutableTensor(CPU)->Resize(dims);
BlobGetMutableTensor(&prefetched_blobs_[i], CPU)->Resize(dims);
}
}
for (int i = 0; i < protos.protos_size(); ++i) {
TensorCPU* dst = prefetched_blobs_[i].GetMutableTensor(CPU);
TensorCPU* dst = BlobGetMutableTensor(&prefetched_blobs_[i], CPU);
TensorCPU& src = temp_tensors[i];
if (protos.protos(i).has_device_detail()) {
protos.mutable_protos(i)->clear_device_detail();

View File

@ -52,7 +52,7 @@ class TTLinearOp final : public Operator<Context> {
int cores_idx = 0;
// Temporary buffer to facilitate multiplication of TT-cores with input
auto Y_buf = Y_temp_->GetMutableTensor(Context::GetDeviceType());
auto Y_buf = BlobGetMutableTensor(Y_temp_.get(), Context::GetDeviceType());
Y_buf->ResizeLike(X);
Y_buf->CopyFrom(X);

View File

@ -19,7 +19,7 @@ static void AddConstInput(
option.set_device_type(PROTO_CUDA);
CUDAContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CUDA);
auto* tensor = BlobGetMutableTensor(blob, CUDA);
tensor->Resize(shape);
math::Set<float, CUDAContext>(
tensor->size(), value, tensor->template mutable_data<float>(), &context);

View File

@ -16,7 +16,7 @@ static void AddConstInput(
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::Set<float, CPUContext>(
tensor->size(), value, tensor->template mutable_data<float>(), &context);

View File

@ -44,10 +44,10 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
CAFFE_ENFORCE(
bnInputs.size() >= 5, "Invalid batch normalization input size");
#define EXPOSE_TENSOR_DATA(name, index, inputs) \
auto name = repr::nn::get<repr::Tensor>(inputs[index]); \
assert(ws->HasBlob(name->getName()) && "Blob not in workspace"); \
auto name##Tensor = ws->GetBlob(name->getName())->GetMutableTensor(CPU); \
#define EXPOSE_TENSOR_DATA(name, index, inputs) \
auto name = repr::nn::get<repr::Tensor>(inputs[index]); \
assert(ws->HasBlob(name->getName()) && "Blob not in workspace"); \
auto name##Tensor = BlobGetMutableTensor(ws->GetBlob(name->getName()), CPU); \
auto name##Data = name##Tensor->mutable_data<float>();
EXPOSE_TENSOR_DATA(filter, 1, convInputs);
@ -76,7 +76,7 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
nn->dataFlow.createEdge(convBiasNode, convNode);
auto* blob = ws->CreateBlob(convBiasName);
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
// Get output channel
size_t c = filterTensor->dim32(0);

View File

@ -173,7 +173,7 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOp(
// Feed into workspace as CPU Tensors
auto* blob = ws->CreateBlob(t.name());
auto* cpu_tensor = blob->GetMutableTensor(CPU);
auto* cpu_tensor = BlobGetMutableTensor(blob, CPU);
std::vector<int64_t> dims;
for(const auto& d : t.dims()) {
dims.push_back(d);

View File

@ -10,14 +10,14 @@ void enforceIsTensor(Workspace* ws, const std::string& name) {
auto blob = ws->GetBlob(name);
CAFFE_ENFORCE(blob, "Blob does not exist: ", name);
CAFFE_ENFORCE(
blob->IsTensorType(CPU), "Blob is not a CPU Tensor: ", name);
BlobIsTensorType(*blob, CPU), "Blob is not a CPU Tensor: ", name);
}
// Returns a mutable pointer to the CPU tensor stored in workspace blob `name`.
// enforceIsTensor(ws, name) is called first; the blob is then looked up again
// and CAFFE_ENFORCE fails with "Blob: <name> does not exist" if it is null.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were lost.
// The two `return` statements appear to be the removed line
// (`blob->GetMutableTensor(CPU)`) followed by its added replacement
// (`BlobGetMutableTensor(blob, CPU)`) -- alternatives, not sequential code.
TensorCPU* getTensor(Workspace* ws, const std::string& name) {
enforceIsTensor(ws, name);
auto* blob = ws->GetBlob(name);
CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
return blob->GetMutableTensor(CPU);
return BlobGetMutableTensor(blob, CPU);
}
void shareInputTensor(
@ -60,7 +60,7 @@ Predictor::Predictor(PredictorConfig config) : config_(std::move(config)) {
for (const auto& name : config_.predict_net->external_input()) {
if (!initialized.count(name)) {
auto* blob = config_.ws->CreateBlob(name);
blob->GetMutableTensor(CPU);
BlobGetMutableTensor(blob, CPU);
}
}
CAFFE_ENFORCE(config_.ws->CreateNet(config_.predict_net));

View File

@ -135,7 +135,7 @@ std::unique_ptr<Blob> randomTensor(
const std::vector<int64_t>& dims,
CPUContext* ctx) {
auto blob = make_unique<Blob>();
auto* t = blob->GetMutableTensor(CPU);
auto* t = BlobGetMutableTensor(blob.get(), CPU);
t->Resize(dims);
math::RandUniform<float, CPUContext>(
t->size(), -1.0, 1.0, t->template mutable_data<float>(), ctx);
@ -180,7 +180,7 @@ TEST_F(PredictorTest, SimpleBatchSized) {
auto inputData = randomTensor({1, 4}, ctx_.get());
Predictor::TensorList input;
input.emplace_back(CPU);
auto tensor = inputData->GetMutableTensor(CPU);
auto tensor = BlobGetMutableTensor(inputData.get(), CPU);
input.back().ResizeLike(*tensor);
input.back().ShareData(*tensor);
Predictor::TensorList output;
@ -196,7 +196,7 @@ TEST_F(PredictorTest, SimpleBatchSizedMapInput) {
auto inputData = randomTensor({1, 4}, ctx_.get());
Predictor::TensorMap input;
auto iter = input.emplace("data", Tensor(CPU));
auto tensor = inputData->GetMutableTensor(CPU);
auto tensor = BlobGetMutableTensor(inputData.get(), CPU);
iter.first->second.ResizeLike(*tensor);
iter.first->second.ShareData(*tensor);

View File

@ -328,7 +328,7 @@ void addObjectMethods(py::module& m) {
})
.def(
"tensor",
[](Blob* blob) { return py::cast(blob->GetMutableTensor(CPU)); },
[](Blob* blob) { return py::cast(BlobGetMutableTensor(blob, CPU)); },
py::return_value_policy::reference_internal)
.def(
"_feed",

View File

@ -234,7 +234,7 @@ class TensorFeeder : public BlobFeederBase {
FeedTensor(
option,
original_array,
blob->GetMutableTensor(Context::GetDeviceType()));
BlobGetMutableTensor(blob, Context::GetDeviceType()));
}
};
@ -366,31 +366,32 @@ class PythonOpBase : public Operator<Context> {
// make sure output blob is initialized before creating the binding
if (forced_cpu_outputs_.count(i)) {
blob->GetMutableTensor(Context::GetDeviceType());
BlobGetMutableTensor(blob, Context::GetDeviceType());
} else {
blob->GetMutableTensor(Context::GetDeviceType());
BlobGetMutableTensor(blob, Context::GetDeviceType());
}
py::object py_obj;
if (blob->template IsType<Tensor>()) {
if (use_dlpack) {
DLPackWrapper<CPUContext> wrapper(
blob->GetMutableTensor(Context::GetDeviceType()), cpu_option);
BlobGetMutableTensor(blob, Context::GetDeviceType()),
cpu_option);
py_obj = py::cast(wrapper, py::return_value_policy::copy);
} else {
py_obj = py::cast(
blob->GetMutableTensor(Context::GetDeviceType()),
BlobGetMutableTensor(blob, Context::GetDeviceType()),
py::return_value_policy::reference);
}
} else {
if (use_dlpack) {
DLPackWrapper<Context> wrapper(
blob->GetMutableTensor(Context::GetDeviceType()),
BlobGetMutableTensor(blob, Context::GetDeviceType()),
this->device_option());
py_obj = py::cast(wrapper, py::return_value_policy::copy);
} else {
py_obj = py::cast(
blob->GetMutableTensor(Context::GetDeviceType()),
BlobGetMutableTensor(blob, Context::GetDeviceType()),
py::return_value_policy::reference);
}
}

View File

@ -163,8 +163,8 @@ public:
DeviceOption cpu_option(option);
cpu_option.set_device_type(DeviceTypeProto::PROTO_CPU);
TensorFeeder<CPUContext> cpu_tensor_feeder;
cpu_tensor_feeder.FeedTensor(cpu_option, original_array,
blob->GetMutableTensor(CPU));
cpu_tensor_feeder.FeedTensor(
cpu_option, original_array, BlobGetMutableTensor(blob, CPU));
}
} catch (ideep::error &e) {
LOG(ERROR) << "IDEEP error: " << e.message;

View File

@ -19,7 +19,7 @@ void AddNoiseInput(
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::RandGaussian<float, CPUContext>(

View File

@ -231,11 +231,12 @@ bool NNPACKConvOp::RunOnDeviceWithOrderNCHW() {
(transformedFilterSize + sizeof(float) - 1) / sizeof(float);
for (auto g = 0; g < group_; g++) {
transformedFilters_[g] = ws_->CreateBlob(
"__transformed_kernel_" +
to_string(__sync_fetch_and_add(
&precomputed_transform_id, 1)))
->GetMutableTensor(CPU);
transformedFilters_[g] = BlobGetMutableTensor(
ws_->CreateBlob(
"__transformed_kernel_" +
to_string(
__sync_fetch_and_add(&precomputed_transform_id, 1))),
CPU);
transformedFilters_[g]->Resize(transformedFilterElements);
status = nnp_convolution_inference(

View File

@ -19,7 +19,7 @@ void AddNoiseInput(
DeviceOption option;
CPUContext context(option);
Blob* blob = ws->CreateBlob(name);
auto* tensor = blob->GetMutableTensor(CPU);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
math::RandGaussian<float, CPUContext>(

View File

@ -26,13 +26,13 @@ TEST(MathROCBLASTest, GemmNoTransNoTrans) {
vector<int> shapeX{5, 10};
vector<int> shapeW{10, 6};
vector<int> shapeY{5, 6};
auto* tensorX = blobX->GetMutableTensor(HIP);
auto* tensorX = BlobGetMutableTensor(blobX, HIP);
tensorX->Resize(shapeX);
auto* tensorW = blobW->GetMutableTensor(HIP);
auto* tensorW = BlobGetMutableTensor(blobW, HIP);
tensorW->Resize(shapeW);
auto* tensorY = blobY->GetMutableTensor(HIP);
auto* tensorY = BlobGetMutableTensor(blobY, HIP);
tensorY->Resize(shapeY);
auto* tensorY_host = blobY_host->GetMutableTensor(CPU);
auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU);
tensorY_host->Resize(shapeY);
EXPECT_EQ(tensorX->size(), 50);
@ -126,13 +126,13 @@ TEST(MathROCBLASTest, GemmNoTransTrans) {
vector<int> shapeX{5, 10};
vector<int> shapeW{6, 10};
vector<int> shapeY{5, 6};
auto* tensorX = blobX->GetMutableTensor(HIP);
auto* tensorX = BlobGetMutableTensor(blobX, HIP);
tensorX->Resize(shapeX);
auto* tensorW = blobW->GetMutableTensor(HIP);
auto* tensorW = BlobGetMutableTensor(blobW, HIP);
tensorW->Resize(shapeW);
auto* tensorY = blobY->GetMutableTensor(HIP);
auto* tensorY = BlobGetMutableTensor(blobY, HIP);
tensorY->Resize(shapeY);
auto* tensorY_host = blobY_host->GetMutableTensor(CPU);
auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU);
tensorY_host->Resize(shapeY);
EXPECT_EQ(tensorX->size(), 50);
@ -225,13 +225,13 @@ TEST(MathROCBLASTest, GemvNoTrans) {
vector<int> shapeA{5, 10};
vector<int> shapeX{10};
vector<int> shapeY{5};
auto* tensorA = blobA->GetMutableTensor(HIP);
auto* tensorA = BlobGetMutableTensor(blobA, HIP);
tensorA->Resize(shapeA);
auto* tensorX = blobX->GetMutableTensor(HIP);
auto* tensorX = BlobGetMutableTensor(blobX, HIP);
tensorX->Resize(shapeX);
auto* tensorY = blobY->GetMutableTensor(HIP);
auto* tensorY = BlobGetMutableTensor(blobY, HIP);
tensorY->Resize(shapeY);
auto* tensorY_host = blobY_host->GetMutableTensor(CPU);
auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU);
tensorY_host->Resize(shapeY);
EXPECT_EQ(tensorA->size(), 50);
@ -315,13 +315,13 @@ TEST(MathROCBLASTest, GemvTrans) {
vector<int> shapeA{6, 10};
vector<int> shapeX{6};
vector<int> shapeY{10};
auto* tensorA = blobA->GetMutableTensor(HIP);
auto* tensorA = BlobGetMutableTensor(blobA, HIP);
tensorA->Resize(shapeA);
auto* tensorX = blobX->GetMutableTensor(HIP);
auto* tensorX = BlobGetMutableTensor(blobX, HIP);
tensorX->Resize(shapeX);
auto* tensorY = blobY->GetMutableTensor(HIP);
auto* tensorY = BlobGetMutableTensor(blobY, HIP);
tensorY->Resize(shapeY);
auto* tensorY_host = blobY_host->GetMutableTensor(CPU);
auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU);
tensorY_host->Resize(shapeY);
EXPECT_EQ(tensorA->size(), 60);

View File

@ -41,9 +41,9 @@ void executeGpuBinaryOpTest(
Blob* bloby = ws.CreateBlob("Y");
Blob* bloby_host = ws.CreateBlob("Y_host");
auto* tensorx0 = blobx0->GetMutableTensor(CUDA);
auto* tensorx1 = blobx1->GetMutableTensor(CUDA);
auto* tensory = bloby->GetMutableTensor(CUDA);
auto* tensorx0 = BlobGetMutableTensor(blobx0, CUDA);
auto* tensorx1 = BlobGetMutableTensor(blobx1, CUDA);
auto* tensory = BlobGetMutableTensor(bloby, CUDA);
vector<int> shapex0_vector{shapex0};
vector<int> shapex1_vector{shapex1};
@ -71,7 +71,7 @@ void executeGpuBinaryOpTest(
context.FinishDeviceComputation();
// Copy result to CPU so we can inspect it
auto* tensory_host = bloby_host->GetMutableTensor(CPU);
auto* tensory_host = BlobGetMutableTensor(bloby_host, CPU);
tensory_host->CopyFrom(*tensory, &context);
context.FinishDeviceComputation();
@ -94,7 +94,7 @@ TEST(MathUtilGPUTest, testAddStripedBatch) {
vector<int> shapex{33 * 9, 25};
vector<int> shapey{33, 25};
auto* tensorx = blobx->GetMutableTensor(CUDA);
auto* tensorx = BlobGetMutableTensor(blobx, CUDA);
tensorx->Resize(shapex);
int stripe = 33 * 25;
vector<float> tot(33, 0.0);
@ -110,7 +110,7 @@ TEST(MathUtilGPUTest, testAddStripedBatch) {
}
}
auto* tensory = bloby->GetMutableTensor(CUDA);
auto* tensory = BlobGetMutableTensor(bloby, CUDA);
tensory->Resize(shapey);
math::Set<float, CUDAContext>(
stripe, 0.0, tensory->mutable_data<float>(), &context);
@ -125,7 +125,7 @@ TEST(MathUtilGPUTest, testAddStripedBatch) {
context.FinishDeviceComputation();
// Copy result to CPU so we can inspect it
auto* tensory_host = bloby_host->GetMutableTensor(CPU);
auto* tensory_host = BlobGetMutableTensor(bloby_host, CPU);
tensory_host->CopyFrom(*tensory, &context);
context.FinishDeviceComputation();
@ -258,9 +258,9 @@ class GemmBatchedGPUTest
Blob* X_blob = ws_.CreateBlob("X");
Blob* W_blob = ws_.CreateBlob("W");
Blob* Y_blob = ws_.CreateBlob("Y");
X_ = X_blob->GetMutableTensor(CUDA);
W_ = W_blob->GetMutableTensor(CUDA);
Y_ = Y_blob->GetMutableTensor(CUDA);
X_ = BlobGetMutableTensor(X_blob, CUDA);
W_ = BlobGetMutableTensor(W_blob, CUDA);
Y_ = BlobGetMutableTensor(Y_blob, CUDA);
X_->Resize(std::vector<int64_t>{3, 5, 10});
W_->Resize(std::vector<int64_t>{3, 6, 10});
Y_->Resize(std::vector<int64_t>{3, 5, 6});
@ -381,8 +381,8 @@ class ReduceTensorGPUTest : public testing::Test {
cuda_context_ = make_unique<CUDAContext>(option_);
Blob* blob_x = ws_.CreateBlob("X");
Blob* blob_y = ws_.CreateBlob("Y");
X_ = blob_x->GetMutableTensor(CUDA);
Y_ = blob_y->GetMutableTensor(CUDA);
X_ = BlobGetMutableTensor(blob_x, CUDA);
Y_ = BlobGetMutableTensor(blob_y, CUDA);
}
void SetUpData(
@ -402,7 +402,7 @@ class ReduceTensorGPUTest : public testing::Test {
void VerifyResult(const std::vector<float>& expected_output) {
Blob* blob_y_host = ws_.CreateBlob("Y_host");
auto* Y_host = blob_y_host->GetMutableTensor(CPU);
auto* Y_host = BlobGetMutableTensor(blob_y_host, CPU);
Y_host->CopyFrom(*Y_, cuda_context_.get());
cuda_context_->FinishDeviceComputation();
ASSERT_EQ(expected_output.size(), Y_host->size());
@ -664,8 +664,8 @@ class BroadcastGPUTest : public testing::Test {
cuda_context_ = make_unique<CUDAContext>(option_);
Blob* blob_x = ws_.CreateBlob("X");
Blob* blob_y = ws_.CreateBlob("Y");
X_ = blob_x->GetMutableTensor(CUDA);
Y_ = blob_y->GetMutableTensor(CUDA);
X_ = BlobGetMutableTensor(blob_x, CUDA);
Y_ = BlobGetMutableTensor(blob_y, CUDA);
}
void SetUpData(
@ -681,7 +681,7 @@ class BroadcastGPUTest : public testing::Test {
void VerifyResult(const std::vector<float>& expected_output) {
Blob* blob_y_host = ws_.CreateBlob("Y_host");
auto* Y_host = blob_y_host->GetMutableTensor(CPU);
auto* Y_host = BlobGetMutableTensor(blob_y_host, CPU);
Y_host->CopyFrom(*Y_, cuda_context_.get());
cuda_context_->FinishDeviceComputation();
ASSERT_EQ(expected_output.size(), Y_host->size());
@ -741,9 +741,9 @@ class MomentsGPUTest : public testing::Test {
Blob* blob_x = ws_.CreateBlob("X");
Blob* blob_mean = ws_.CreateBlob("mean");
Blob* blob_variance = ws_.CreateBlob("variance");
X_ = blob_x->GetMutableTensor(CUDA);
mean_ = blob_mean->GetMutableTensor(CUDA);
variance_ = blob_variance->GetMutableTensor(CUDA);
X_ = BlobGetMutableTensor(blob_x, CUDA);
mean_ = BlobGetMutableTensor(blob_mean, CUDA);
variance_ = BlobGetMutableTensor(blob_variance, CUDA);
}
void SetUpData(
@ -766,10 +766,10 @@ class MomentsGPUTest : public testing::Test {
const std::vector<float>& mean_data,
const std::vector<float>& variance_data) {
Blob* blob_mean_host = ws_.CreateBlob("mean_host");
auto* mean_host = blob_mean_host->GetMutableTensor(CPU);
auto* mean_host = BlobGetMutableTensor(blob_mean_host, CPU);
mean_host->CopyFrom(*mean_, cuda_context_.get());
Blob* blob_variance_host = ws_.CreateBlob("variance_host");
auto* variance_host = blob_variance_host->GetMutableTensor(CPU);
auto* variance_host = BlobGetMutableTensor(blob_variance_host, CPU);
variance_host->CopyFrom(*variance_, cuda_context_.get());
cuda_context_->FinishDeviceComputation();
@ -868,8 +868,8 @@ class TransposeGPUTest : public testing::Test {
cuda_context_ = make_unique<CUDAContext>(option_);
Blob* blob_x = ws_.CreateBlob("X");
Blob* blob_y = ws_.CreateBlob("Y");
X_ = blob_x->GetMutableTensor(CUDA);
Y_ = blob_y->GetMutableTensor(CUDA);
X_ = BlobGetMutableTensor(blob_x, CUDA);
Y_ = BlobGetMutableTensor(blob_y, CUDA);
}
void SetUpData(
@ -890,7 +890,7 @@ class TransposeGPUTest : public testing::Test {
void VerifyResult(const std::vector<float>& expected_output) {
Blob* blob_y_host = ws_.CreateBlob("Y_host");
auto* Y_host = blob_y_host->GetMutableTensor(CPU);
auto* Y_host = BlobGetMutableTensor(blob_y_host, CPU);
Y_host->CopyFrom(*Y_, cuda_context_.get());
cuda_context_->FinishDeviceComputation();
ASSERT_EQ(expected_output.size(), Y_host->size());