mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Removing some dependency edges from Blob to other caffe2 (#12043)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12043 Re-trying D9979976, this time with all call sites fixed. D9979976 was reverted because, it seems, one call site was not covered by sandcastle. I fixed that call site and used 'grep' to ensure there aren't any more call sites in fbsource. Reviewed By: ezyang Differential Revision: D10026392 fbshipit-source-id: cd341514a8e53a40147ea0ee3e52f63bb6444157
This commit is contained in:
committed by
Facebook Github Bot
parent
94c513cc7f
commit
8f0db9bbbb
@ -163,7 +163,7 @@ void loadInput(
|
||||
CAFFE_THROW("Not support GPU on mobile.");
|
||||
#endif
|
||||
} else {
|
||||
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
|
||||
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
|
||||
CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
if (input_type_list[i] == "uint8_t") {
|
||||
@ -200,7 +200,7 @@ void fillInputBlob(
|
||||
int protos_size = tensor_kv.second.protos_size();
|
||||
caffe2::TensorProto* tensor_proto =
|
||||
tensor_kv.second.mutable_protos(iteration % protos_size);
|
||||
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
|
||||
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
|
||||
if (tensor_proto->data_type() == caffe2::TensorProto::STRING) {
|
||||
int total_size = tensor_proto->string_data_size();
|
||||
for (size_t i = 0; i < total_size; i++) {
|
||||
@ -298,7 +298,7 @@ void writeOutput(
|
||||
#endif
|
||||
} else {
|
||||
writeTextOutput<caffe2::CPUContext, caffe2::TensorCPU>(
|
||||
workspace->GetBlob(name)->GetMutableTensor(caffe2::CPU),
|
||||
BlobGetMutableTensor(workspace->GetBlob(name), caffe2::CPU),
|
||||
output_prefix,
|
||||
name);
|
||||
}
|
||||
|
@ -137,7 +137,7 @@ int main(int argc, char** argv) {
|
||||
if (blob == nullptr) {
|
||||
blob = workspace->CreateBlob(input_names[i]);
|
||||
}
|
||||
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
|
||||
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
|
||||
CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
if (input_type_list[i] == "uint8_t") {
|
||||
|
@ -12,7 +12,7 @@ namespace caffe2 {
|
||||
namespace gloo {
|
||||
|
||||
void signalFailure(Blob* status_blob, std::exception& /* unused */) {
|
||||
auto* res = status_blob->GetMutableTensor(CPU);
|
||||
auto* res = BlobGetMutableTensor(status_blob, CPU);
|
||||
res->Resize(1);
|
||||
res->template mutable_data<int32_t>()[0] = 1;
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ static void AddConstInput(const std::vector<int>& shape, const float value,
|
||||
option.set_device_type(PROTO_CUDA);
|
||||
CUDAContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CUDA);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CUDA);
|
||||
tensor->Resize(shape);
|
||||
math::Set<float, CUDAContext>(tensor->size(), value,
|
||||
tensor->mutable_data<float>(),
|
||||
|
@ -95,10 +95,10 @@ void BlobToTensorProto(
|
||||
}
|
||||
|
||||
// Set values
|
||||
if (blob->IsTensorType(CPU)) {
|
||||
if (BlobIsTensorType(*blob, CPU)) {
|
||||
const auto& cpu_tensor = blob->template Get<TensorCPU>();
|
||||
CPUTensorToTensorProto(cpu_tensor, t);
|
||||
} else if (blob->IsTensorType(CUDA)) {
|
||||
} else if (BlobIsTensorType(*blob, CUDA)) {
|
||||
const auto& cuda_tensor = blob->template Get<TensorCUDA>();
|
||||
const auto cpu_tensor = TensorCPU(cuda_tensor, context);
|
||||
context->FinishDeviceComputation();
|
||||
|
@ -6,16 +6,16 @@
|
||||
#include <typeinfo>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "caffe2/core/blob_serializer_base.h"
|
||||
#include "caffe2/core/common.h"
|
||||
|
||||
#include <ATen/core/typeid.h>
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/tensor.h"
|
||||
#include "caffe2/core/typeid.h"
|
||||
#include "caffe2/proto/caffe2_pb.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
class Tensor;
|
||||
|
||||
/**
|
||||
* @brief Blob is a general container that hosts a typed pointer.
|
||||
*
|
||||
@ -50,15 +50,6 @@ class CAFFE2_API Blob final {
|
||||
return meta_.Match<T>();
|
||||
}
|
||||
|
||||
bool IsTensorType(DeviceType device_type) const {
|
||||
bool is_match = meta_.Match<Tensor>();
|
||||
auto* tensor = static_cast<Tensor*>(pointer_);
|
||||
if (is_match && tensor && tensor->GetDeviceType() == device_type) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the meta info of the blob.
|
||||
*/
|
||||
@ -109,9 +100,6 @@ class CAFFE2_API Blob final {
|
||||
std::is_default_constructible<T>::value,
|
||||
"GetMutable can't be called with non-default-constructible types. "
|
||||
"Try using specialized methods");
|
||||
static_assert(
|
||||
!std::is_same<T, Tensor>::value,
|
||||
"Use GetMutableTensor(DeviceType) instead");
|
||||
if (IsType<T>()) {
|
||||
return static_cast<T*>(pointer_);
|
||||
} else {
|
||||
@ -129,16 +117,6 @@ class CAFFE2_API Blob final {
|
||||
}
|
||||
}
|
||||
|
||||
inline Tensor* GetMutableTensor(DeviceType device_type) {
|
||||
if (IsTensorType(device_type)) {
|
||||
return static_cast<Tensor*>(pointer_);
|
||||
} else {
|
||||
VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
|
||||
<< " DeviceType:" << device_type;
|
||||
return Reset<Tensor>(new Tensor(device_type));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the underlying object to the allocated one. The Blob then takes over
|
||||
* the ownership of the passed in pointer. If there is already an object in
|
||||
@ -248,5 +226,29 @@ inline void swap(Blob& lhs, Blob& rhs) {
|
||||
lhs.swap(rhs);
|
||||
}
|
||||
|
||||
/**
 * Reports whether `blob` currently holds a Tensor whose device type equals
 * `device_type`.
 *
 * Returns false immediately when the blob's type meta does not match Tensor;
 * otherwise compares the held Tensor's GetDeviceType() with `device_type`.
 */
inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
|
||||
bool is_match = blob.meta().Match<Tensor>();
|
||||
if (!is_match) {
|
||||
return false;
|
||||
}
|
||||
// NOTE(review): `tensor` is the address of the reference returned by
// Get<Tensor>(), so the null test in the return below looks redundant --
// confirm before relying on it or removing it.
const Tensor* tensor = &blob.Get<Tensor>();
|
||||
return tensor && tensor->GetDeviceType() == device_type;
|
||||
}
|
||||
|
||||
/**
 * Returns a mutable Tensor of the requested device type stored in `blob`.
 *
 * If the blob already holds a Tensor whose GetDeviceType() equals
 * `device_type`, that Tensor is returned as-is. Otherwise a new
 * Tensor(device_type) is allocated, handed to blob->Reset<Tensor>(), and the
 * pointer returned by Reset is passed back to the caller.
 *
 * NOTE(review): `blob` is dereferenced without a null check; callers
 * presumably guarantee it is non-null -- confirm.
 */
inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
|
||||
if (blob->IsType<Tensor>()) {
|
||||
Tensor* tensor = blob->GetMutable<Tensor>();
|
||||
// Reuse the existing Tensor only when its device type already matches.
if (tensor->GetDeviceType() == device_type) {
|
||||
return tensor;
|
||||
}
|
||||
}
|
||||

||||
// if we're here, then either Blob didn't hold a Tensor
|
||||
// or that Tensor had the wrong DeviceType.
|
||||
VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
|
||||
<< " DeviceType:" << device_type;
|
||||
// The freshly allocated Tensor is passed to Reset, whose result is returned.
return blob->Reset<Tensor>(new Tensor(device_type));
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
#endif // CAFFE2_CORE_BLOB_H_
|
||||
|
@ -132,7 +132,7 @@ TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
|
||||
for (int i = 0; i < 6; ++i) { \
|
||||
cpu_tensor.mutable_data<TypeParam>()[i] = static_cast<TypeParam>(i); \
|
||||
} \
|
||||
blob.GetMutableTensor(CUDA)->CopyFrom(cpu_tensor); \
|
||||
BlobGetMutableTensor(&blob, CUDA)->CopyFrom(cpu_tensor); \
|
||||
string serialized = SerializeBlob(blob, "test"); \
|
||||
BlobProto proto; \
|
||||
CAFFE_ENFORCE(proto.ParseFromString(serialized)); \
|
||||
@ -149,7 +149,7 @@ TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
|
||||
} \
|
||||
Blob new_blob; \
|
||||
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \
|
||||
EXPECT_TRUE(new_blob.IsTensorType(CUDA)); \
|
||||
EXPECT_TRUE(BlobIsTensorType(new_blob, CUDA)); \
|
||||
Tensor new_cpu_tensor(blob.Get<Tensor>(), CPU); \
|
||||
EXPECT_EQ(new_cpu_tensor.ndim(), 2); \
|
||||
EXPECT_EQ(new_cpu_tensor.dim(0), 2); \
|
||||
@ -199,7 +199,7 @@ TEST(TensorTest, TensorSerializationMultiDevices) {
|
||||
// Test if the restored blob is still of the same device.
|
||||
blob.Reset();
|
||||
EXPECT_NO_THROW(DeserializeBlob(serialized, &blob));
|
||||
EXPECT_TRUE(blob.IsTensorType(CUDA));
|
||||
EXPECT_TRUE(BlobIsTensorType(blob, CUDA));
|
||||
EXPECT_EQ(GetGPUIDForPointer(blob.Get<TensorCUDA>().data<float>()),
|
||||
gpu_id);
|
||||
// Test if we force the restored blob on a different device, we
|
||||
@ -207,7 +207,7 @@ TEST(TensorTest, TensorSerializationMultiDevices) {
|
||||
blob.Reset();
|
||||
proto.mutable_tensor()->mutable_device_detail()->set_cuda_gpu_id(0);
|
||||
EXPECT_NO_THROW(DeserializeBlob(proto.SerializeAsString(), &blob));
|
||||
EXPECT_TRUE(blob.IsTensorType(CUDA));
|
||||
EXPECT_TRUE(BlobIsTensorType(blob, CUDA));
|
||||
EXPECT_EQ(GetGPUIDForPointer(blob.Get<TensorCUDA>().data<float>()), 0);
|
||||
}
|
||||
}
|
||||
|
@ -363,7 +363,8 @@ void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
|
||||
auto tensor_proto = blob_proto.tensor();
|
||||
Deserialize(
|
||||
tensor_proto,
|
||||
blob->GetMutableTensor(
|
||||
BlobGetMutableTensor(
|
||||
blob,
|
||||
static_cast<DeviceType>(tensor_proto.device_detail().device_type())));
|
||||
}
|
||||
|
||||
|
@ -86,15 +86,15 @@ TEST(BlobTest, Blob) {
|
||||
int* int_unused CAFFE2_UNUSED = blob.GetMutable<int>();
|
||||
EXPECT_TRUE(blob.IsType<int>());
|
||||
EXPECT_FALSE(blob.IsType<BlobTestFoo>());
|
||||
EXPECT_FALSE(blob.IsTensorType(CPU));
|
||||
EXPECT_FALSE(BlobIsTensorType(blob, CPU));
|
||||
|
||||
BlobTestFoo* foo_unused CAFFE2_UNUSED = blob.GetMutable<BlobTestFoo>();
|
||||
EXPECT_TRUE(blob.IsType<BlobTestFoo>());
|
||||
EXPECT_FALSE(blob.IsType<int>());
|
||||
EXPECT_FALSE(blob.IsTensorType(CPU));
|
||||
EXPECT_FALSE(BlobIsTensorType(blob, CPU));
|
||||
|
||||
Tensor* tensor_unused CAFFE2_UNUSED = blob.GetMutableTensor(CPU);
|
||||
EXPECT_TRUE(blob.IsTensorType(CPU));
|
||||
Tensor* tensor_unused CAFFE2_UNUSED = BlobGetMutableTensor(&blob, CPU);
|
||||
EXPECT_TRUE(BlobIsTensorType(blob, CPU));
|
||||
EXPECT_FALSE(blob.IsType<BlobTestFoo>());
|
||||
EXPECT_FALSE(blob.IsType<int>());
|
||||
}
|
||||
@ -600,7 +600,7 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) {
|
||||
#define TEST_SERIALIZATION_WITH_TYPE(TypeParam, field_name) \
|
||||
TEST(TensorTest, TensorSerialization_##TypeParam) { \
|
||||
Blob blob; \
|
||||
Tensor* tensor = blob.GetMutableTensor(CPU); \
|
||||
Tensor* tensor = BlobGetMutableTensor(&blob, CPU); \
|
||||
tensor->Resize(2, 3); \
|
||||
for (int i = 0; i < 6; ++i) { \
|
||||
tensor->mutable_data<TypeParam>()[i] = static_cast<TypeParam>(i); \
|
||||
@ -621,7 +621,7 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) {
|
||||
} \
|
||||
Blob new_blob; \
|
||||
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \
|
||||
EXPECT_TRUE(new_blob.IsTensorType(CPU)); \
|
||||
EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); \
|
||||
const TensorCPU& new_tensor = blob.Get<TensorCPU>(); \
|
||||
EXPECT_EQ(new_tensor.ndim(), 2); \
|
||||
EXPECT_EQ(new_tensor.dim(0), 2); \
|
||||
@ -634,7 +634,7 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) {
|
||||
\
|
||||
TEST(EmptyTensorTest, TensorSerialization_##TypeParam) { \
|
||||
Blob blob; \
|
||||
TensorCPU* tensor = blob.GetMutableTensor(CPU); \
|
||||
TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU); \
|
||||
tensor->Resize(0, 3); \
|
||||
tensor->mutable_data<TypeParam>(); \
|
||||
string serialized = SerializeBlob(blob, "test"); \
|
||||
@ -650,7 +650,7 @@ TEST(TensorDeathTest, CannotCastDownLargeDims) {
|
||||
EXPECT_EQ(tensor_proto.field_name##_size(), 0); \
|
||||
Blob new_blob; \
|
||||
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob)); \
|
||||
EXPECT_TRUE(new_blob.IsTensorType(CPU)); \
|
||||
EXPECT_TRUE(BlobIsTensorType(new_blob, CPU)); \
|
||||
const TensorCPU& new_tensor = blob.Get<TensorCPU>(); \
|
||||
EXPECT_EQ(new_tensor.ndim(), 2); \
|
||||
EXPECT_EQ(new_tensor.dim(0), 0); \
|
||||
@ -669,7 +669,7 @@ TEST_SERIALIZATION_WITH_TYPE(int64_t, int64_data)
|
||||
|
||||
TEST(TensorTest, TensorSerialization_CustomType) {
|
||||
Blob blob;
|
||||
TensorCPU* tensor = blob.GetMutableTensor(CPU);
|
||||
TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU);
|
||||
tensor->Resize(2, 3);
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
tensor->mutable_data<BlobTestFoo>()[i].val = i;
|
||||
@ -681,7 +681,7 @@ TEST(TensorTest, TensorSerialization_CustomType) {
|
||||
EXPECT_EQ(proto.type(), "Tensor");
|
||||
Blob new_blob;
|
||||
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob));
|
||||
EXPECT_TRUE(new_blob.IsTensorType(CPU));
|
||||
EXPECT_TRUE(BlobIsTensorType(new_blob, CPU));
|
||||
const TensorCPU& new_tensor = blob.Get<TensorCPU>();
|
||||
EXPECT_EQ(new_tensor.ndim(), 2);
|
||||
EXPECT_EQ(new_tensor.dim(0), 2);
|
||||
@ -696,7 +696,7 @@ TEST(TensorTest, TensorSerialization_CustomType) {
|
||||
TEST(TensorTest, Half) {
|
||||
const int64_t kSize = 3000000;
|
||||
Blob blob;
|
||||
TensorCPU* tensor = blob.GetMutableTensor(CPU);
|
||||
TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU);
|
||||
tensor->Resize(kSize);
|
||||
for (int i = 0; i < tensor->size(); ++i) {
|
||||
tensor->mutable_data<at::Half>()[i].x = i % 10000;
|
||||
@ -724,7 +724,7 @@ TEST(TensorTest, Half) {
|
||||
}
|
||||
Blob new_blob;
|
||||
EXPECT_NO_THROW(DeserializeBlob(serialized, &new_blob));
|
||||
EXPECT_TRUE(new_blob.IsTensorType(CPU));
|
||||
EXPECT_TRUE(BlobIsTensorType(new_blob, CPU));
|
||||
const TensorCPU& new_tensor = blob.Get<TensorCPU>();
|
||||
EXPECT_EQ(new_tensor.ndim(), 1);
|
||||
EXPECT_EQ(new_tensor.dim(0), kSize);
|
||||
@ -860,7 +860,7 @@ TYPED_TEST(TypedTensorTest, BigTensorSerialization) {
|
||||
{
|
||||
VLOG(1) << "Test begin";
|
||||
Blob blob;
|
||||
Tensor* tensor = blob.GetMutableTensor(CPU);
|
||||
Tensor* tensor = BlobGetMutableTensor(&blob, CPU);
|
||||
VLOG(1) << "Allocating blob";
|
||||
tensor->Resize(d1, d2);
|
||||
auto mutableData = tensor->mutable_data<TypeParam>();
|
||||
@ -903,7 +903,7 @@ TYPED_TEST(TypedTensorTest, BigTensorSerialization) {
|
||||
load_op->Run();
|
||||
VLOG(1) << "Reading blob from workspace";
|
||||
auto new_blob = ws.GetBlob("test");
|
||||
EXPECT_TRUE(new_blob->IsTensorType(CPU));
|
||||
EXPECT_TRUE(BlobIsTensorType(*new_blob, CPU));
|
||||
const auto& new_tensor = new_blob->Get<TensorCPU>();
|
||||
|
||||
EXPECT_EQ(new_tensor.ndim(), d1);
|
||||
@ -1030,7 +1030,7 @@ TEST(CustomChunkSize, BigTensorSerialization) {
|
||||
int64_t size = d1 * d2;
|
||||
|
||||
Blob blob;
|
||||
TensorCPU* tensor = blob.GetMutableTensor(CPU);
|
||||
TensorCPU* tensor = BlobGetMutableTensor(&blob, CPU);
|
||||
tensor->Resize(d1, d2);
|
||||
tensor->mutable_data<float>();
|
||||
std::mutex mutex;
|
||||
|
@ -122,7 +122,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
static_assert(
|
||||
std::is_same<T, Tensor>::value,
|
||||
"Output(int, DeviceType) is only available for Tensor");
|
||||
return outputs_.at(idx)->GetMutableTensor(type);
|
||||
return BlobGetMutableTensor(outputs_.at(idx), type);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -149,7 +149,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
}
|
||||
|
||||
inline bool InputIsTensorType(int idx, DeviceType device_type) {
|
||||
return inputs_.at(idx)->IsTensorType(device_type);
|
||||
return BlobIsTensorType(*inputs_.at(idx), device_type);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -162,7 +162,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
|
||||
}
|
||||
|
||||
inline bool OutputIsTensorType(int idx, DeviceType type) {
|
||||
return outputs_.at(idx)->IsTensorType(type);
|
||||
return BlobIsTensorType(*outputs_.at(idx), type);
|
||||
}
|
||||
|
||||
inline int InputSize() const {
|
||||
|
@ -131,7 +131,8 @@ struct WorkspaceIdInjector {
|
||||
"Integer overflow while calculating GLOBAL_WORKSPACE_ID blob");
|
||||
int32_t global_ws_id = (seq_++) + (static_cast<int32_t>(node_id) << 16);
|
||||
Blob* global_ws_id_blob = workspace->CreateLocalBlob(GLOBAL_WORKSPACE_ID);
|
||||
TensorCPU* global_ws_id_tensor = global_ws_id_blob->GetMutableTensor(CPU);
|
||||
TensorCPU* global_ws_id_tensor =
|
||||
BlobGetMutableTensor(global_ws_id_blob, CPU);
|
||||
global_ws_id_tensor->Resize();
|
||||
global_ws_id_tensor->template mutable_data<int32_t>()[0] = global_ws_id;
|
||||
VLOG(1) << "Adding " << GLOBAL_WORKSPACE_ID << " = " << global_ws_id;
|
||||
|
@ -151,7 +151,7 @@ class CAFFE2_API Workspace {
|
||||
auto* to_blob = CreateBlob(blob);
|
||||
CAFFE_ENFORCE(to_blob);
|
||||
const auto& from_tensor = from_blob->template Get<Tensor>();
|
||||
auto* to_tensor = to_blob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* to_tensor = BlobGetMutableTensor(to_blob, Context::GetDeviceType());
|
||||
to_tensor->CopyFrom(from_tensor);
|
||||
}
|
||||
}
|
||||
|
@ -33,8 +33,9 @@ class IDEEPConcatOp final : public IDEEPOperator {
|
||||
if (OperatorBase::InputBlob(i).template IsType<itensor>()) {
|
||||
inputs.emplace_back(Input(i));
|
||||
} else {
|
||||
CAFFE_ENFORCE(OperatorBase::InputBlob(i).IsTensorType(CPU),
|
||||
"Expect cpu tensor if not itensor");
|
||||
CAFFE_ENFORCE(
|
||||
BlobIsTensorType(OperatorBase::InputBlob(i), CPU),
|
||||
"Expect cpu tensor if not itensor");
|
||||
auto& tensor_cpu = OperatorBase::Input<Tensor>(i, CPU);
|
||||
CAFFE_ENFORCE(tensor_cpu.dims().size() == 0 ||
|
||||
tensor_cpu.size_from_dim(0) == 0,
|
||||
|
@ -89,7 +89,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
|
||||
local_input_blobs_[i]->Reset();
|
||||
}
|
||||
input_share_[i] = false;
|
||||
auto dtensor = local_input_blobs_[i]->GetMutableTensor(CPU);
|
||||
auto dtensor = BlobGetMutableTensor(local_input_blobs_[i], CPU);
|
||||
dtensor->Resize(input.get_dims());
|
||||
if (input.is_public_format()) {
|
||||
dtensor->ShareExternalPointer(
|
||||
@ -121,7 +121,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
|
||||
continue;
|
||||
}
|
||||
CAFFE_ENFORCE(
|
||||
local_output_blobs_[i]->IsTensorType(CPU),
|
||||
BlobIsTensorType(*local_output_blobs_[i], CPU),
|
||||
"IDEEP fallback op currently does not support non-TensorCPU "
|
||||
"output type who needs copying.");
|
||||
const auto& src = local_output_blobs_[i]->template Get<TensorCPU>();
|
||||
@ -153,7 +153,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
|
||||
VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor";
|
||||
Blob* dst = OperatorBase::OutputBlob(i);
|
||||
dst->Reset(new Tensor(CPU));
|
||||
auto dtensor = dst->GetMutableTensor(CPU);
|
||||
auto dtensor = BlobGetMutableTensor(dst, CPU);
|
||||
dtensor->Resize(src_dims);
|
||||
dtensor->ShareData(src);
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ class CopyIDEEPToCPUOp final : public IDEEPOperator {
|
||||
USE_IDEEP_DEF_ALIASES();
|
||||
bool RunOnDevice() override {
|
||||
const auto& input_blob = OperatorBase::InputBlob(0);
|
||||
if (input_blob.IsTensorType(CPU)) {
|
||||
if (BlobIsTensorType(input_blob, CPU)) {
|
||||
VLOG(2) << "Directing sharing of TensorCPU";
|
||||
const auto& X = OperatorBase::Input<Tensor>(0, CPU);
|
||||
auto* Y = OperatorBase::Output<Tensor>(0, CPU);
|
||||
|
@ -66,10 +66,10 @@ class MKLFallbackOp final : public Operator<MKLContext> {
|
||||
for (int i = 0; i < InputSize(); ++i) {
|
||||
if (OperatorBase::InputIsType<MKLMemory<float>>(i)) {
|
||||
OperatorBase::Input<MKLMemory<float>>(i).CopyTo(
|
||||
local_input_blobs_[i]->GetMutableTensor(CPU));
|
||||
BlobGetMutableTensor(local_input_blobs_[i], CPU));
|
||||
} else if (OperatorBase::InputIsType<MKLMemory<double>>(i)) {
|
||||
OperatorBase::Input<MKLMemory<double>>(i).CopyTo(
|
||||
local_input_blobs_[i]->GetMutableTensor(CPU));
|
||||
BlobGetMutableTensor(local_input_blobs_[i], CPU));
|
||||
} else {
|
||||
VLOG(1) << "Input " << i << " is not MKLMemory. Skipping copy.";
|
||||
// Note(jiayq): This removes a const but conceptually
|
||||
@ -93,7 +93,7 @@ class MKLFallbackOp final : public Operator<MKLContext> {
|
||||
continue;
|
||||
}
|
||||
CAFFE_ENFORCE(
|
||||
local_output_blobs_[i]->IsTensorType(CPU),
|
||||
BlobIsTensorType(*local_output_blobs_[i], CPU),
|
||||
"MKL fallback op currently does not support non-TensorCPU "
|
||||
"output type who needs copying.");
|
||||
const auto& src = local_output_blobs_[i]->template Get<TensorCPU>();
|
||||
|
@ -43,7 +43,7 @@ bool CopyFromGLOp<T>::RunOnDevice() {
|
||||
if (first_run_) {
|
||||
first_run_ = false;
|
||||
for (int i = 0; i < Inputs().size(); ++i) {
|
||||
auto* Y = OperatorBase::Outputs()[i]->GetMutableTensor(CPU);
|
||||
auto* Y = BlobGetMutableTensor(OperatorBase::Outputs()[i], CPU);
|
||||
Y->Resize(inputs_[i]->dims());
|
||||
Y->template mutable_data<float>();
|
||||
}
|
||||
@ -54,7 +54,7 @@ bool CopyFromGLOp<T>::RunOnDevice() {
|
||||
// GLTensor
|
||||
auto* X = inputs_[i].get();
|
||||
X->lazy_allocate(Xblob, second_run_, true);
|
||||
auto* Y = OperatorBase::Outputs()[i]->GetMutableTensor(CPU);
|
||||
auto* Y = BlobGetMutableTensor(OperatorBase::Outputs()[i], CPU);
|
||||
Timer timer;
|
||||
timer.Start();
|
||||
getTensorCPU(*X, *Y);
|
||||
|
@ -27,7 +27,7 @@ template<typename T = float>
|
||||
void PopulateCPUBlob(Workspace *ws, bool random, std::string name,
|
||||
std::vector<int> dims, int val = 1, int dist_shift = 0, float variance = 1) {
|
||||
Blob *blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(dims);
|
||||
T *t_data = tensor->mutable_data<T>();
|
||||
std::random_device rd;
|
||||
|
@ -489,13 +489,13 @@ class MPSCNNPackedInt8BGRANHWCToNCHWCStylizerPreprocessOp final
|
||||
"noise_size", 491 /* prime to avoid artifacts */);
|
||||
// Treated as half4 in the kernel, so need half4 here.
|
||||
noiseSize = divRoundUp(noiseSize, 4) * 4;
|
||||
if (!noiseBlob->IsTensorType(CPU) ||
|
||||
if (!BlobIsTensorType(*noiseBlob, CPU) ||
|
||||
noiseBlob->Get<TensorCPU>().size() != noiseSize) {
|
||||
VLOG(2) << "Initializing stylizer with noise: " << noiseSize;
|
||||
caffe2::Timer rt;
|
||||
// Initialize random noise on first use.
|
||||
// Cache it to maintain temporal consistency.
|
||||
auto* t = noiseBlob->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(noiseBlob, CPU);
|
||||
t->Resize(noiseSize);
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
t->size(),
|
||||
|
@ -94,7 +94,7 @@ void testMPSCNN() {
|
||||
|
||||
Workspace ws;
|
||||
for (auto i = 0; i < N; ++i) {
|
||||
auto* t = ws.CreateBlob(cpu(i))->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob(cpu(i)), CPU);
|
||||
t->Resize(BS, C, H, W);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -152,7 +152,7 @@ void testMPSCNN() {
|
||||
|
||||
Workspace ws;
|
||||
for (auto i = 0; i < N; ++i) {
|
||||
auto* t = ws.CreateBlob(cpu(i))->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob(cpu(i)), CPU);
|
||||
switch (ndim) {
|
||||
case 1:
|
||||
t->Resize(5);
|
||||
@ -210,7 +210,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNNormalizePlanarYUV Test: ";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batch_size, channels, 8, 13);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -218,14 +218,14 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
|
||||
t->Resize(1, channels);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
t->size(), 0, 1, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws.CreateBlob("stddev")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("stddev"), CPU);
|
||||
t->Resize(1, channels);
|
||||
CPUContext ctx;
|
||||
math::RandUniform<float, CPUContext>(
|
||||
@ -290,7 +290,7 @@ void testMPSCNN() {
|
||||
for (const auto dim : {10, 40}) {
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batchSize, channels, dim, dim);
|
||||
CPUContext ctx;
|
||||
// Too noisy.
|
||||
@ -299,7 +299,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(channels);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -310,7 +310,7 @@ void testMPSCNN() {
|
||||
// t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(channels);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -321,7 +321,7 @@ void testMPSCNN() {
|
||||
// t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws.CreateBlob("pw")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("pw"), CPU);
|
||||
t->Resize(prelu == PreluTy::SHARED ? 1 : channels);
|
||||
CPUContext ctx;
|
||||
// Too noisy.
|
||||
@ -409,7 +409,7 @@ void testMPSCNN() {
|
||||
Workspace ws;
|
||||
const auto channels = array ? 12 : 3;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batch_size, channels, 8, 13);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -417,7 +417,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(shared ? channels : 1);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -480,7 +480,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNSpatialBN Test: " << channels;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batch_size, channels, 8, 13);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -488,7 +488,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
for (const std::string name : {"scale", "bias", "mean", "var"}) {
|
||||
auto* t = ws.CreateBlob(name)->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob(name), CPU);
|
||||
t->Resize(channels);
|
||||
CPUContext ctx;
|
||||
// High mean to avoid var division by zero.
|
||||
@ -575,7 +575,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNFC Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batchSize, CIn, H, W);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -583,7 +583,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(COut, CIn * H * W);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -591,7 +591,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(COut);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -682,8 +682,8 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNPool Test: " << pool;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t =
|
||||
ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(
|
||||
ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batchSize, 8, 8, 13);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -784,7 +784,7 @@ void testMPSCNN() {
|
||||
std::vector<std::vector<size_t>>{{1, 3, 50, 80}, {1, 12, 50, 80}}) {
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(dims);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -860,7 +860,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNPreprocess Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(1, 8, 13, 4);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -869,7 +869,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
|
||||
t->Resize(3);
|
||||
CPUContext ctx;
|
||||
t->mutable_data<float>()[0] = 100;
|
||||
@ -940,7 +940,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNDeprocess Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(1, 3, 8, 24);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -949,7 +949,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
|
||||
t->Resize(3);
|
||||
CPUContext ctx;
|
||||
t->mutable_data<float>()[0] = 100;
|
||||
@ -999,7 +999,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNDeprocess Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(1, 3, 1280, 720);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -1008,7 +1008,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
|
||||
t->Resize(3);
|
||||
CPUContext ctx;
|
||||
t->mutable_data<float>()[0] = 30;
|
||||
@ -1072,7 +1072,8 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNConv Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t =
|
||||
BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batchSize, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1080,7 +1081,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(8, 12, kernel_h, kernel_w);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1092,7 +1093,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(8);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1188,7 +1189,7 @@ void testMPSCNN() {
|
||||
Workspace ws;
|
||||
int output_channels = input_channels * channel_multiplier;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batchSize, input_channels, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1196,7 +1197,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(output_channels, 1, 3, 3);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1204,7 +1205,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(output_channels);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1275,7 +1276,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNConvRelu Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(1, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1283,7 +1284,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(8, 12, 3, 3);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1291,7 +1292,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(8);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1385,7 +1386,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSConv Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(1, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1393,7 +1394,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(8, 12, 3, 3);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1401,7 +1402,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(8);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1493,7 +1494,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSConv Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batchSize, C, 12, 16);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1501,7 +1502,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(M, C, K, K);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1509,7 +1510,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(M);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1607,7 +1608,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNConv Test - group";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batchSize, C, 12, 16);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1615,7 +1616,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(M, C / group, K, K);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1623,7 +1624,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(M);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1726,7 +1727,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNMul Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X0_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X0_cpu"), CPU);
|
||||
t->Resize(1, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1734,7 +1735,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("X1_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X1_cpu"), CPU);
|
||||
t->Resize(72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1791,7 +1792,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNSub Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X0_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X0_cpu"), CPU);
|
||||
t->Resize(1, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1799,7 +1800,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("X1_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X1_cpu"), CPU);
|
||||
t->Resize(72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1856,7 +1857,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSAdd Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X0_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X0_cpu"), CPU);
|
||||
t->Resize(1, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1864,7 +1865,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("X1_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X1_cpu"), CPU);
|
||||
t->Resize(1, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1921,7 +1922,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSAdd Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X0_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X0_cpu"), CPU);
|
||||
t->Resize(1, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -1929,7 +1930,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("X1_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X1_cpu"), CPU);
|
||||
t->Resize(1, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -2011,7 +2012,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNNeuron Test: " << n;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(1, 4, 12, 12);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -2065,7 +2066,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNDropout Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(1, 12, 57, 72);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -2136,7 +2137,7 @@ void testMPSCNN() {
|
||||
<< " - scale: " << scale;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(1, channels, 40, 40);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -2144,7 +2145,7 @@ void testMPSCNN() {
|
||||
}
|
||||
{
|
||||
// Use the batch-first encoding (n, [bbox])
|
||||
auto* t = ws.CreateBlob("R")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("R"), CPU);
|
||||
t->Resize(6, 5);
|
||||
for (auto i = 0; i < t->dim32(0); ++i) {
|
||||
t->mutable_data<float>()[5 * i + 0] = 0; // batch
|
||||
@ -2250,14 +2251,14 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNRoIWarp Test 2";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(1, 8, 40, 40);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
t->size(), 4, 2, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws.CreateBlob("R")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("R"), CPU);
|
||||
t->Resize(6, 4);
|
||||
for (auto i = 0; i < t->dim32(0); ++i) {
|
||||
t->mutable_data<float>()[4 * i + 0] = (i % 4 + 1) * 1.0 / scale;
|
||||
@ -2362,7 +2363,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNResizeNearestOp Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, 37, 89);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -2497,7 +2498,7 @@ void testMPSCNN() {
|
||||
vector<float> im_info{60, 80, 0.166667};
|
||||
vector<float> anchors{-38, -16, 53, 31, -120, -120, 135, 135};
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(num_images, A, H, W);
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
t->mutable_data<float>()[i] = scores[i];
|
||||
@ -2505,7 +2506,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("bbox_delta_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("bbox_delta_cpu"), CPU);
|
||||
t->Resize(num_images, 4 * A, H, W);
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
t->mutable_data<float>()[i] = bbx[i];
|
||||
@ -2513,7 +2514,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("im_info")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("im_info"), CPU);
|
||||
t->Resize(num_images, 3);
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
t->mutable_data<float>()[i] = im_info[i];
|
||||
@ -2521,7 +2522,7 @@ void testMPSCNN() {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("anchors")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("anchors"), CPU);
|
||||
t->Resize(A, 4);
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
t->mutable_data<float>()[i] = anchors[i];
|
||||
@ -2587,7 +2588,7 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSCNNSoftmax Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
// Only works for spatial dimension of (1, 1) - weird.
|
||||
t->Resize(batchSize, 12, 1, 1);
|
||||
CPUContext ctx;
|
||||
@ -2661,8 +2662,8 @@ void testMPSCNN() {
|
||||
LOG(INFO) << "MPSConvTranspose Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t =
|
||||
ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(
|
||||
ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batchSize, inputChannels, 8, 12);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -2675,7 +2676,7 @@ void testMPSCNN() {
|
||||
|
||||
{
|
||||
auto* t =
|
||||
ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(
|
||||
inputChannels,
|
||||
outputChannels,
|
||||
@ -2692,7 +2693,7 @@ void testMPSCNN() {
|
||||
|
||||
{
|
||||
auto* t =
|
||||
ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(outputChannels);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -2809,7 +2810,7 @@ void testMPSCNN() {
|
||||
<< batchSize;
|
||||
Workspace ws;
|
||||
for (auto i = 0; i < numInputs; ++i) {
|
||||
auto* t = ws.CreateBlob(cpu(i))->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob(cpu(i)), CPU);
|
||||
t->Resize(batchSize, array ? (i + 1) * 4 : 4, 10, 10);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -2891,7 +2892,7 @@ void testMPSCNN() {
|
||||
}
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(batchSize, inputChannels, 53, 47);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -2964,7 +2965,7 @@ void testMPSCNN() {
|
||||
<< numInputs << ", " << batchSize;
|
||||
Workspace ws;
|
||||
for (auto i = 0; i < numInputs; ++i) {
|
||||
auto* t = ws.CreateBlob(cpu(i))->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob(cpu(i)), CPU);
|
||||
t->Resize(batchSize, channelCount, 9, 17);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -3336,8 +3337,8 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
|
||||
Workspace cws;
|
||||
cws.RunNetOnce(initNet);
|
||||
{
|
||||
auto* t =
|
||||
cws.CreateBlob(predictNet.external_input(0))->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(
|
||||
cws.CreateBlob(predictNet.external_input(0)), CPU);
|
||||
t->Resize(1, 224, 224, 4);
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
t->mutable_data<uint8_t>()[i] = i % 225;
|
||||
@ -3348,8 +3349,8 @@ void compareModels(const NetDef& initNet, NetDef predictNet) {
|
||||
Workspace mws;
|
||||
mws.RunNetOnce(initNet);
|
||||
{
|
||||
auto* t =
|
||||
mws.CreateBlob(predictNet.external_input(0))->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(
|
||||
mws.CreateBlob(predictNet.external_input(0)), CPU);
|
||||
t->Resize(1, 224, 224, 4);
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
t->mutable_data<uint8_t>()[i] = i % 225;
|
||||
@ -3397,16 +3398,16 @@ void verifyRewrite(
|
||||
dumpDef(predictNet);
|
||||
dumpDef(metalPredictNet);
|
||||
|
||||
#define RUN_NET(ws, predictNet) \
|
||||
ws.RunNetOnce(initNet); \
|
||||
{ \
|
||||
auto* t = \
|
||||
ws.CreateBlob(predictNet.external_input(0))->GetMutableTensor(CPU); \
|
||||
t->Resize(inputDims); \
|
||||
CPUContext ctx; \
|
||||
math::RandGaussian<float, CPUContext>( \
|
||||
t->size(), 0, 1, t->mutable_data<float>(), &ctx); \
|
||||
} \
|
||||
#define RUN_NET(ws, predictNet) \
|
||||
ws.RunNetOnce(initNet); \
|
||||
{ \
|
||||
auto* t = BlobGetMutableTensor( \
|
||||
ws.CreateBlob(predictNet.external_input(0)), CPU); \
|
||||
t->Resize(inputDims); \
|
||||
CPUContext ctx; \
|
||||
math::RandGaussian<float, CPUContext>( \
|
||||
t->size(), 0, 1, t->mutable_data<float>(), &ctx); \
|
||||
} \
|
||||
ws.RunNetOnce(predictNet);
|
||||
|
||||
// initialize
|
||||
|
@ -16,7 +16,7 @@ void AddNoiseInput(const vector<int64_t>& shape, const string& name, Workspace*
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
|
@ -16,7 +16,7 @@ void AddNoiseInput(const vector<int64_t>& shape, const string& name, Workspace*
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
|
@ -679,7 +679,7 @@ void NNApi::init(const TensorVector& inputs, TensorVector* outputs) {
|
||||
output_dims.push_back(dim);
|
||||
}
|
||||
|
||||
auto* tensor = ws_.CreateBlob(blob)->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(ws_.CreateBlob(blob), CPU);
|
||||
tensor->Resize(output_dims);
|
||||
outputs->push_back(tensor);
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
#include "caffe2/core/init.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/core/tensor.h"
|
||||
@ -43,14 +43,14 @@ static double benchmark_conv_caffe2(
|
||||
ws = &localWs;
|
||||
}
|
||||
{
|
||||
auto* t = ws->CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws->CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws->CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws->CreateBlob("W"), CPU);
|
||||
if (group == 1) {
|
||||
t->Resize(K, C, kernel, kernel);
|
||||
} else {
|
||||
@ -61,7 +61,7 @@ static double benchmark_conv_caffe2(
|
||||
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws->CreateBlob("B")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws->CreateBlob("B"), CPU);
|
||||
t->Resize(K);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -129,14 +129,14 @@ static double benchmark_conv_nnapi(
|
||||
ws = &localWs;
|
||||
}
|
||||
{
|
||||
auto* t = ws->CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws->CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, H, W, C);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws->CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws->CreateBlob("W"), CPU);
|
||||
if (group > 1) {
|
||||
CAFFE_ENFORCE_EQ(C, group);
|
||||
t->Resize(1, kernel, kernel, C);
|
||||
@ -148,7 +148,7 @@ static double benchmark_conv_nnapi(
|
||||
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws->CreateBlob("B")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws->CreateBlob("B"), CPU);
|
||||
t->Resize(K);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -190,7 +190,7 @@ static double benchmark_conv_nnapi(
|
||||
NetDef initNet;
|
||||
NNApi model(initNet, netdef, ws);
|
||||
std::vector<TensorCPU*> inputs, outputs;
|
||||
inputs.push_back(ws->GetBlob("X_cpu")->GetMutableTensor(CPU));
|
||||
inputs.push_back(BlobGetMutableTensor(ws->GetBlob("X_cpu"), CPU));
|
||||
CAFFE_ENFORCE(model.run(inputs, &outputs));
|
||||
|
||||
for (int i = 0; i < warmup; i++) {
|
||||
@ -220,14 +220,14 @@ static double benchmark_conv_nnapi_int8(
|
||||
ws = &localWs;
|
||||
}
|
||||
{
|
||||
auto* t = ws->CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws->CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, H, W, C);
|
||||
for (int i = 0; i < t->size(); i++) {
|
||||
t->mutable_data<uint8_t>()[i] = rand() % 10;
|
||||
}
|
||||
}
|
||||
{
|
||||
auto* t = ws->CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws->CreateBlob("W"), CPU);
|
||||
if (group > 1) {
|
||||
CAFFE_ENFORCE_EQ(C, group);
|
||||
t->Resize(1, kernel, kernel, C);
|
||||
@ -243,7 +243,7 @@ static double benchmark_conv_nnapi_int8(
|
||||
// should be of ANEURALNETWORKS_TENSOR_INT32, with zeroPoint of 0 and
|
||||
// bias_scale == input_scale * filter_scale.
|
||||
{
|
||||
auto* t = ws->CreateBlob("B")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws->CreateBlob("B"), CPU);
|
||||
t->Resize(K);
|
||||
for (int i = 0; i < t->size(); i++) {
|
||||
t->mutable_data<int32_t>()[i] = rand() % 10;
|
||||
@ -322,7 +322,7 @@ static double benchmark_conv_nnapi_int8(
|
||||
NetDef initNet;
|
||||
NNApi model(initNet, netdef, ws);
|
||||
std::vector<TensorCPU*> inputs, outputs;
|
||||
inputs.push_back(ws->GetBlob("X_cpu")->GetMutableTensor(CPU));
|
||||
inputs.push_back(BlobGetMutableTensor(ws->GetBlob("X_cpu"), CPU));
|
||||
CAFFE_ENFORCE(model.run(inputs, &outputs));
|
||||
|
||||
for (int i = 0; i < warmup; i++) {
|
||||
|
@ -55,7 +55,7 @@ static void test_relu(int N, int C, int H, int W) {
|
||||
// CPU reference
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, H, W, C);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -81,7 +81,7 @@ static void test_relu(int N, int C, int H, int W) {
|
||||
NetDef initNet;
|
||||
NNApi model(initNet, netdef, &ws);
|
||||
std::vector<TensorCPU*> inputs, outputs;
|
||||
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
|
||||
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
|
||||
EXPECT_TRUE(model.run(inputs, &outputs));
|
||||
const auto& t_nn = *outputs[0];
|
||||
|
||||
@ -103,21 +103,21 @@ static void test_conv_NHWC(
|
||||
int stride_w) {
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, H, W, C);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(K, kernel, kernel, C);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws.CreateBlob("B")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("B"), CPU);
|
||||
t->Resize(K);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -189,7 +189,7 @@ static void test_conv_NHWC(
|
||||
NetDef initNet;
|
||||
NNApi model(initNet, netdef, &ws);
|
||||
std::vector<TensorCPU*> inputs, outputs;
|
||||
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
|
||||
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
|
||||
EXPECT_TRUE(model.run(inputs, &outputs));
|
||||
const auto& t_nn = *outputs[0];
|
||||
|
||||
@ -211,21 +211,21 @@ static void test_depthwise_conv_NHWC(
|
||||
int stride_w) {
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, H, W, C);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(1, kernel, kernel, D);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
t->size(), 0, 30, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
{
|
||||
auto* t = ws.CreateBlob("B")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("B"), CPU);
|
||||
t->Resize(D);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -406,7 +406,7 @@ static void test_depthwise_conv_NHWC(
|
||||
NetDef initNet;
|
||||
NNApi model(initNet, netdef, &ws);
|
||||
std::vector<TensorCPU*> inputs, outputs;
|
||||
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
|
||||
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
|
||||
EXPECT_TRUE(model.run(inputs, &outputs));
|
||||
const auto& t_nn = *outputs[0];
|
||||
|
||||
@ -428,7 +428,7 @@ static void test_pooling(
|
||||
int stride_w) {
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, H, W, C);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
@ -496,7 +496,7 @@ static void test_pooling(
|
||||
NetDef initNet;
|
||||
NNApi model(initNet, netdef, &ws);
|
||||
std::vector<TensorCPU*> inputs, outputs;
|
||||
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
|
||||
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
|
||||
EXPECT_TRUE(model.run(inputs, &outputs));
|
||||
const auto& t_nn = *outputs[0];
|
||||
|
||||
@ -506,7 +506,7 @@ static void test_pooling(
|
||||
static void test_softmax(int N, int C, int H = 1, int W = 1) {
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
if (H == 1 && W == 1) {
|
||||
t->Resize(N, C);
|
||||
} else {
|
||||
@ -538,7 +538,7 @@ static void test_softmax(int N, int C, int H = 1, int W = 1) {
|
||||
NetDef initNet;
|
||||
NNApi model(initNet, netdef, &ws);
|
||||
std::vector<TensorCPU*> inputs, outputs;
|
||||
inputs.push_back(ws.GetBlob("X_cpu")->GetMutableTensor(CPU));
|
||||
inputs.push_back(BlobGetMutableTensor(ws.GetBlob("X_cpu"), CPU));
|
||||
EXPECT_TRUE(model.run(inputs, &outputs));
|
||||
const auto& t_nn = *outputs[0];
|
||||
|
||||
|
@ -178,7 +178,7 @@ void testOpenGLCopyOps(int N, int C, int H, int W, float error, int tile_x = 1,
|
||||
LOG(INFO) << "OPENGLCopyFrom/To Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
|
||||
@ -275,7 +275,7 @@ void testOpenGLConv(int N,
|
||||
<< " Op: " << glPoolOperationName[poolOp];
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
if (random_input) {
|
||||
@ -301,7 +301,7 @@ void testOpenGLConv(int N,
|
||||
}
|
||||
|
||||
if (poolOp != AveragePool && poolOp != MaxPool) {
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
if (poolOp != ConvTranspose && poolOp != ConvTransposePRelu && poolOp != ConvTransposeRelu) {
|
||||
t->Resize(K, C, kernel_h, kernel_w);
|
||||
} else {
|
||||
@ -343,7 +343,7 @@ void testOpenGLConv(int N,
|
||||
|
||||
// bias
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(K);
|
||||
CPUContext ctx;
|
||||
if (random_input) {
|
||||
@ -367,7 +367,7 @@ void testOpenGLConv(int N,
|
||||
}
|
||||
|
||||
if (poolOp == ConvPRelu || poolOp == ConvTransposePRelu) {
|
||||
auto* t = ws.CreateBlob("p")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("p"), CPU);
|
||||
t->Resize(K);
|
||||
CPUContext ctx;
|
||||
if (random_input) {
|
||||
@ -532,7 +532,7 @@ void testOpenGLPRelu(
|
||||
<< "C: " << C << ", H: " << H << ", W: " << W;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
// Too noisy.
|
||||
@ -541,7 +541,7 @@ void testOpenGLPRelu(
|
||||
|
||||
// prelu scale
|
||||
{
|
||||
auto* t = ws.CreateBlob("p")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("p"), CPU);
|
||||
t->Resize(prelu_size);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
|
||||
@ -603,7 +603,7 @@ void testOpenGLRelu(int N, int C, int H, int W, int input_tile_x, int input_tile
|
||||
<< "C: " << C << ", H: " << H << ", W: " << W;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
// Too noisy.
|
||||
@ -664,13 +664,13 @@ void testOpenGLAdd(int N, int C, int H, int W, float error = 0.1, int input_tile
|
||||
<< "C: " << C << ", H: " << H << ", W: " << W;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t0 = ws.CreateBlob("X_cpu0")->GetMutableTensor(CPU);
|
||||
auto* t0 = BlobGetMutableTensor(ws.CreateBlob("X_cpu0"), CPU);
|
||||
t0->Resize(N, C, H, W);
|
||||
CPUContext ctx0;
|
||||
// Too noisy.
|
||||
math::RandGaussian<float, CPUContext>(t0->size(), 0, 30, t0->mutable_data<float>(), &ctx0);
|
||||
|
||||
auto* t1 = ws.CreateBlob("X_cpu1")->GetMutableTensor(CPU);
|
||||
auto* t1 = BlobGetMutableTensor(ws.CreateBlob("X_cpu1"), CPU);
|
||||
t1->Resize(N, C, H, W);
|
||||
CPUContext ctx1;
|
||||
// Too noisy.
|
||||
@ -750,13 +750,13 @@ void testOpenGLSub(int N, int C, int H, int W, float error = 0.1) {
|
||||
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t0 = ws.CreateBlob("X_cpu0")->GetMutableTensor(CPU);
|
||||
auto* t0 = BlobGetMutableTensor(ws.CreateBlob("X_cpu0"), CPU);
|
||||
t0->Resize(N, C, H, W);
|
||||
CPUContext ctx0;
|
||||
// Too noisy.
|
||||
math::RandGaussian<float, CPUContext>(t0->size(), 0, 30, t0->mutable_data<float>(), &ctx0);
|
||||
|
||||
auto* t1 = ws.CreateBlob("X_cpu1")->GetMutableTensor(CPU);
|
||||
auto* t1 = BlobGetMutableTensor(ws.CreateBlob("X_cpu1"), CPU);
|
||||
t1->Resize(N, C, H, W);
|
||||
CPUContext ctx1;
|
||||
// Too noisy.
|
||||
@ -814,8 +814,8 @@ void testOpenGLConcat(int N, std::vector<int> Cs, int H, int W, bool tiling = fa
|
||||
<< "H: " << H << ", W: " << W;
|
||||
Workspace ws;
|
||||
for (int i = 0; i < Cs.size(); i++) {
|
||||
auto* t =
|
||||
ws.CreateBlob("X_cpu" + caffe2::to_string(i))->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(
|
||||
ws.CreateBlob("X_cpu" + caffe2::to_string(i)), CPU);
|
||||
t->Resize(N, Cs[i], H, W);
|
||||
CPUContext ctx0;
|
||||
// Too noisy.
|
||||
@ -891,7 +891,7 @@ void testOpenGLSigmoid(int N, int C, int H, int W, float error) {
|
||||
<< "C: " << C << ", H: " << H << ", W: " << W;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
// Too noisy.
|
||||
@ -942,7 +942,7 @@ void testOpenGLTanh(int N, int C, int H, int W, float error) {
|
||||
<< "C: " << C << ", H: " << H << ", W: " << W;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(t->size(), 0, 2, t->mutable_data<float>(), &ctx);
|
||||
@ -992,14 +992,14 @@ void testOpenGLMul(int N, int C, int H, int W, float error) {
|
||||
<< "C: " << C << ", H: " << H << ", W: " << W;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(t->size(), -10, 10, t->mutable_data<float>(), &ctx);
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("B")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("B"), CPU);
|
||||
t->Resize(1);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(t->size(), -10, 10, t->mutable_data<float>(), &ctx);
|
||||
@ -1060,7 +1060,7 @@ void testOpenGLSoftmax(int N, int D, float error, bool tiled = false) {
|
||||
LOG(INFO) << "OpenGL Softmax Test "
|
||||
<< "N: " << N << " D: " << D << " Tiled:" << tiled;
|
||||
Workspace ws;
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
{
|
||||
t->Resize(N, D);
|
||||
CPUContext ctx;
|
||||
@ -1151,7 +1151,7 @@ void testOpenGLInstanceNorm(int N, int C, int H, int W, float error) {
|
||||
<< "C: " << C << ", H: " << H << ", W: " << W;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
// Too noisy.
|
||||
@ -1163,7 +1163,7 @@ void testOpenGLInstanceNorm(int N, int C, int H, int W, float error) {
|
||||
|
||||
// scale
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(C);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -1172,7 +1172,7 @@ void testOpenGLInstanceNorm(int N, int C, int H, int W, float error) {
|
||||
}
|
||||
// bias
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(C);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -1254,7 +1254,7 @@ void testOpenGLInstanceNormPRelu(int N, int C, int H, int W, float error) {
|
||||
<< "C: " << C << ", H: " << H << ", W: " << W;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
// Too noisy.
|
||||
@ -1266,7 +1266,7 @@ void testOpenGLInstanceNormPRelu(int N, int C, int H, int W, float error) {
|
||||
|
||||
// scale
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(C);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -1275,7 +1275,7 @@ void testOpenGLInstanceNormPRelu(int N, int C, int H, int W, float error) {
|
||||
}
|
||||
// bias
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(C);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -1284,7 +1284,7 @@ void testOpenGLInstanceNormPRelu(int N, int C, int H, int W, float error) {
|
||||
}
|
||||
// prelu scale
|
||||
{
|
||||
auto* t = ws.CreateBlob("p")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("p"), CPU);
|
||||
t->Resize(C);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
|
||||
@ -1385,7 +1385,7 @@ void OpenGL_speedtest(int N,
|
||||
<< " C: " << C << " H: " << H << " W: " << W;
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
if (random_input) {
|
||||
@ -1399,7 +1399,7 @@ void OpenGL_speedtest(int N,
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(K, C, kernel_h, kernel_w);
|
||||
CPUContext ctx;
|
||||
if (random_input) {
|
||||
@ -1413,7 +1413,7 @@ void OpenGL_speedtest(int N,
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(K);
|
||||
CPUContext ctx;
|
||||
if (random_input) {
|
||||
@ -1479,7 +1479,7 @@ void testOpenGLPadImage(
|
||||
{
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
|
||||
@ -1593,7 +1593,7 @@ void testOpenGLResize(int N,
|
||||
{
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
math::RandGaussian<float, CPUContext>(t->size(), 0, 1, t->mutable_data<float>(), &ctx);
|
||||
@ -1675,7 +1675,7 @@ void testOpenGLPreprocess(int N, int C, int H, int W, float error) {
|
||||
LOG(INFO) << "OpenGL Preprocess Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, H, W, C);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -1684,7 +1684,7 @@ void testOpenGLPreprocess(int N, int C, int H, int W, float error) {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
|
||||
t->Resize(3);
|
||||
CPUContext ctx;
|
||||
t->mutable_data<float>()[0] = 100;
|
||||
@ -1748,7 +1748,7 @@ void testOpenGLDeprocess(int N, int C, int H, int W, float error) {
|
||||
LOG(INFO) << "OpenGLDeprocess Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -1757,7 +1757,7 @@ void testOpenGLDeprocess(int N, int C, int H, int W, float error) {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
|
||||
t->Resize(3);
|
||||
CPUContext ctx;
|
||||
t->mutable_data<float>()[0] = 30;
|
||||
@ -1800,7 +1800,7 @@ void testOpenGLNormPlanarYUV(int N, int C, int H, int W, float error) {
|
||||
LOG(INFO) << "OpenGLNormPlanarYUV Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, 3, H, W);
|
||||
CPUContext ctx;
|
||||
for (auto i = 0; i < t->size(); ++i) {
|
||||
@ -1809,7 +1809,7 @@ void testOpenGLNormPlanarYUV(int N, int C, int H, int W, float error) {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("mean")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("mean"), CPU);
|
||||
t->Resize(1, 3);
|
||||
CPUContext ctx;
|
||||
t->mutable_data<float>()[0] = 30;
|
||||
@ -1818,7 +1818,7 @@ void testOpenGLNormPlanarYUV(int N, int C, int H, int W, float error) {
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("stdev")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("stdev"), CPU);
|
||||
t->Resize(1, 3);
|
||||
CPUContext ctx;
|
||||
t->mutable_data<float>()[0] = 6;
|
||||
@ -1879,7 +1879,7 @@ void OpenGL_copyops_speedtest(int N,
|
||||
LOG(INFO) << "OpenGL CopyOps Speed Test";
|
||||
Workspace ws;
|
||||
{
|
||||
auto* t = ws.CreateBlob("X_cpu")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("X_cpu"), CPU);
|
||||
t->Resize(N, C, H, W);
|
||||
CPUContext ctx;
|
||||
if (random_input) {
|
||||
@ -1893,7 +1893,7 @@ void OpenGL_copyops_speedtest(int N,
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
t->Resize(K, C, kernel_h, kernel_w);
|
||||
CPUContext ctx;
|
||||
if (random_input) {
|
||||
@ -1907,7 +1907,7 @@ void OpenGL_copyops_speedtest(int N,
|
||||
}
|
||||
|
||||
{
|
||||
auto* t = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
t->Resize(K);
|
||||
CPUContext ctx;
|
||||
if (random_input) {
|
||||
@ -1990,8 +1990,8 @@ void compareModelsForOpenGL(std::string name,
|
||||
Workspace cws;
|
||||
cws.RunNetOnce(initNet);
|
||||
|
||||
auto* t_cpu = cws.CreateBlob(truncatedPredictNet.external_input(0))
|
||||
->GetMutableTensor(CPU);
|
||||
auto* t_cpu = BlobGetMutableTensor(
|
||||
cws.CreateBlob(truncatedPredictNet.external_input(0)), CPU);
|
||||
if (name == "styleTransfer") {
|
||||
CAFFE_ENFORCE_EQ(input_order, "NHWC");
|
||||
CAFFE_ENFORCE_EQ(input_type, "uint8_t");
|
||||
@ -2032,8 +2032,8 @@ void compareModelsForOpenGL(std::string name,
|
||||
Workspace mws;
|
||||
mws.RunNetOnce(initNet);
|
||||
|
||||
auto* t_gl = mws.CreateBlob(truncatedOpenGLPredictNet.external_input(0))
|
||||
->GetMutableTensor(CPU);
|
||||
auto* t_gl = BlobGetMutableTensor(
|
||||
mws.CreateBlob(truncatedOpenGLPredictNet.external_input(0)), CPU);
|
||||
if (name == "styleTransfer") {
|
||||
CAFFE_ENFORCE_EQ(input_order, "NHWC");
|
||||
CAFFE_ENFORCE_EQ(input_type, "uint8_t");
|
||||
@ -2116,7 +2116,7 @@ void compareBatchedToTiledModels(std::string name,
|
||||
tws.RunNetOnce(initNet);
|
||||
|
||||
auto* t_batch =
|
||||
tws.CreateBlob(bachedNet.external_input(0))->GetMutableTensor(CPU);
|
||||
BlobGetMutableTensor(tws.CreateBlob(bachedNet.external_input(0)), CPU);
|
||||
if (name == "styleTransfer") {
|
||||
CAFFE_ENFORCE_EQ(input_order, "NHWC");
|
||||
CAFFE_ENFORCE_EQ(input_type, "uint8_t");
|
||||
@ -2143,7 +2143,7 @@ void compareBatchedToTiledModels(std::string name,
|
||||
bws.RunNetOnce(initNet);
|
||||
|
||||
auto* t_tiling =
|
||||
bws.CreateBlob(tiledNet.external_input(0))->GetMutableTensor(CPU);
|
||||
BlobGetMutableTensor(bws.CreateBlob(tiledNet.external_input(0)), CPU);
|
||||
if (name == "styleTransfer") {
|
||||
CAFFE_ENFORCE_EQ(input_order, "NHWC");
|
||||
CAFFE_ENFORCE_EQ(input_type, "uint8_t");
|
||||
|
@ -14,7 +14,7 @@
|
||||
#define POPULATE_DATA(_n, _s, _l) \
|
||||
do { \
|
||||
Blob* _blob = ws.CreateBlob((_n)); \
|
||||
auto* _tensor = _blob->GetMutableTensor(CPU); \
|
||||
auto* _tensor = BlobGetMutableTensor(_blob, CPU); \
|
||||
_tensor->Resize((_s)); \
|
||||
memcpy(_tensor->mutable_data<float>(), data_##_l, _tensor->nbytes()); \
|
||||
} while (0)
|
||||
@ -23,7 +23,7 @@
|
||||
#define POPULATE_DATA(_n, _s, _l) \
|
||||
do { \
|
||||
Blob* _blob = ws.CreateBlob((_n)); \
|
||||
auto* _tensor = _blob->GetMutableTensor(CPU); \
|
||||
auto* _tensor = BlobGetMutableTensor(_blob, CPU); \
|
||||
_tensor->Resize((_s)); \
|
||||
memset(_tensor->mutable_data<float>(), 1, _tensor->nbytes()); \
|
||||
} while (0)
|
||||
@ -43,7 +43,7 @@ void AddConstInput(const vector<int64_t>& shape,
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
math::Set<float, CPUContext>(tensor->size(), value,
|
||||
tensor->mutable_data<float>(),
|
||||
@ -56,7 +56,7 @@ void AddNoiseInput(const vector<int64_t>& shape,
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
|
@ -289,13 +289,13 @@ void ConvTest2b1b(int IC, int KH, int KW, int H, int W, int OC, int N, ConvArgs
|
||||
def.add_arg()->CopyFrom(MakeArgument("pad_r", args.pad_r));
|
||||
def.add_arg()->CopyFrom(MakeArgument("pad_t", args.pad_t));
|
||||
def.add_arg()->CopyFrom(MakeArgument("pad_b", args.pad_b));
|
||||
auto* Xws = ws.CreateBlob("X")->GetMutableTensor(CPU);
|
||||
auto* Xws = BlobGetMutableTensor(ws.CreateBlob("X"), CPU);
|
||||
Xws->ResizeLike(X);
|
||||
Xws->ShareExternalPointer(X.mutable_data<float>(), X.size());
|
||||
auto* Wws = ws.CreateBlob("W")->GetMutableTensor(CPU);
|
||||
auto* Wws = BlobGetMutableTensor(ws.CreateBlob("W"), CPU);
|
||||
Wws->ResizeLike(W_);
|
||||
Wws->ShareExternalPointer(W_.mutable_data<float>(), W_.size());
|
||||
auto* bws = ws.CreateBlob("b")->GetMutableTensor(CPU);
|
||||
auto* bws = BlobGetMutableTensor(ws.CreateBlob("b"), CPU);
|
||||
bws->ResizeLike(bias);
|
||||
bws->ShareExternalPointer(bias.mutable_data<float>(), bias.size());
|
||||
ws.RunOperatorOnce(def);
|
||||
|
@ -30,7 +30,7 @@ class BatchMatMulOpGPUTest : public testing::Test {
|
||||
const float value,
|
||||
const string& name) {
|
||||
Blob* blob = ws_.CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CUDA);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CUDA);
|
||||
tensor->Resize(dims);
|
||||
math::Set<float, CUDAContext>(
|
||||
tensor->size(),
|
||||
|
@ -24,7 +24,7 @@ class BatchMatMulOpTest : public testing::Test {
|
||||
const float value,
|
||||
const string& name) {
|
||||
Blob* blob = ws_.CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(dims);
|
||||
math::Set<float, CPUContext>(
|
||||
tensor->size(),
|
||||
|
@ -16,7 +16,7 @@ static void AddScalarInput(
|
||||
Workspace* ws,
|
||||
bool isEmpty = false) {
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
if (!isEmpty) {
|
||||
tensor->Resize(vector<int64_t>{1});
|
||||
*(tensor->template mutable_data<DataT>()) = value;
|
||||
|
@ -27,8 +27,8 @@ void runWithSharedBuffer<CPUContext>(
|
||||
|
||||
auto* mutexPtr = mutexBlob->GetMutable<std::unique_ptr<std::mutex>>();
|
||||
std::lock_guard<std::mutex> g(**mutexPtr);
|
||||
auto* buffer =
|
||||
ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CPU__")->GetMutableTensor(CPU);
|
||||
auto* buffer = BlobGetMutableTensor(
|
||||
ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CPU__"), CPU);
|
||||
f(buffer);
|
||||
}
|
||||
}
|
||||
|
@ -20,8 +20,8 @@ void runWithSharedBuffer<CUDAContext>(
|
||||
|
||||
auto* mutexPtr = mutexBlob->GetMutable<std::unique_ptr<std::mutex>>();
|
||||
std::lock_guard<std::mutex> g(**mutexPtr);
|
||||
auto* buffer =
|
||||
ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA__")->GetMutableTensor(CUDA);
|
||||
auto* buffer = BlobGetMutableTensor(
|
||||
ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA__"), CUDA);
|
||||
f(buffer);
|
||||
}
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ void AddConstInput(const vector<int64_t>& shape,
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
math::Set<float, CPUContext>(
|
||||
tensor->size(), value, tensor->template mutable_data<float>(), &context);
|
||||
@ -29,7 +29,7 @@ void AddNoiseInput(const vector<int64_t>& shape,
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
|
@ -1428,7 +1428,7 @@ class TreeCursorSerializer : public BlobSerializerBase {
|
||||
// serialize offsets as a tensor
|
||||
if (cursor->offsets.size() > 0) {
|
||||
Blob offsets_blob;
|
||||
auto* offsets = offsets_blob.GetMutableTensor(CPU);
|
||||
auto* offsets = BlobGetMutableTensor(&offsets_blob, CPU);
|
||||
offsets->Resize(cursor->offsets.size());
|
||||
std::copy(
|
||||
cursor->offsets.begin(),
|
||||
|
@ -150,7 +150,7 @@ bool CuDNNDropoutOp::DoRunWithType() {
|
||||
// Reshape tensor descriptors if necessary
|
||||
if (X.dims() != cudnn_input_dims_ && !is_test_) {
|
||||
CAFFE_ENFORCE(scratch_blob_);
|
||||
Tensor* states = scratch_blob_->GetMutableTensor(CUDA);
|
||||
Tensor* states = BlobGetMutableTensor(scratch_blob_, CUDA);
|
||||
cudnn_input_dims_ = X.dims();
|
||||
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
|
||||
data_desc_,
|
||||
|
@ -19,7 +19,7 @@ void FillTensor(
|
||||
const std::vector<int64_t>& shape,
|
||||
const std::vector<I_Type>& values) {
|
||||
auto* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* tensor = BlobGetMutableTensor(blob, Context::GetDeviceType());
|
||||
tensor->Resize(shape);
|
||||
auto* mutable_data = tensor->template mutable_data<O_Type>();
|
||||
const O_Type* data = reinterpret_cast<const O_Type*>(values.data());
|
||||
|
@ -18,7 +18,7 @@ static void AddConstInput(
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
math::Set<float, CPUContext>(
|
||||
tensor->size(), value, tensor->template mutable_data<float>(), &context);
|
||||
@ -34,7 +34,7 @@ static void AddLinSpacedInput(
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
EigenVectorMap<float> tensor_vec(
|
||||
tensor->template mutable_data<float>(), tensor->size());
|
||||
@ -51,7 +51,7 @@ static void AddInput(
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
EigenVectorMap<float> tensor_vec(
|
||||
tensor->template mutable_data<float>(), tensor->size());
|
||||
|
@ -353,7 +353,7 @@ class IndexSerializer : public BlobSerializerBase {
|
||||
SerializationAcceptor acceptor) override {
|
||||
auto& base = blob.template Get<std::unique_ptr<IndexBase>>();
|
||||
Blob tensor_blob;
|
||||
auto* tensor_out = tensor_blob.GetMutableTensor(CPU);
|
||||
auto* tensor_out = BlobGetMutableTensor(&tensor_blob, CPU);
|
||||
|
||||
if (base->Type().Match<std::string>()) {
|
||||
doStore<std::string>(base, tensor_out);
|
||||
|
@ -213,23 +213,23 @@ class ONNXWhileOp final : public Operator<Context> {
|
||||
lcd_tensors_.clear();
|
||||
for (int i = 2; i < body_net_def.external_input_size(); ++i) {
|
||||
Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i));
|
||||
Tensor* t = b->GetMutableTensor(Context::GetDeviceType());
|
||||
Tensor* t = BlobGetMutableTensor(b, Context::GetDeviceType());
|
||||
lcd_tensors_.push_back(t);
|
||||
}
|
||||
// First output is the iteration variable
|
||||
auto* iteration_var_blob = loop_ws_->CreateBlob(
|
||||
body_net_def.external_input(0));
|
||||
iteration_var_ =
|
||||
iteration_var_blob->GetMutableTensor(Context::GetDeviceType());
|
||||
BlobGetMutableTensor(iteration_var_blob, Context::GetDeviceType());
|
||||
|
||||
input_condition_var_ =
|
||||
loop_ws_->CreateBlob(body_net_def.external_input(1))
|
||||
->GetMutableTensor(Context::GetDeviceType());
|
||||
input_condition_var_ = BlobGetMutableTensor(
|
||||
loop_ws_->CreateBlob(body_net_def.external_input(1)),
|
||||
Context::GetDeviceType());
|
||||
|
||||
auto* condition_var_blob =
|
||||
loop_ws_->CreateBlob(body_net_def.external_output(0));
|
||||
condition_var_ =
|
||||
condition_var_blob->GetMutableTensor(Context::GetDeviceType());
|
||||
BlobGetMutableTensor(condition_var_blob, Context::GetDeviceType());
|
||||
condition_var_->Resize(1);
|
||||
condition_var_->template mutable_data<bool>();
|
||||
|
||||
|
@ -15,7 +15,7 @@ void BlobToTensorDescriptor(
|
||||
// Memory type
|
||||
// We only allow weights to be CPU tensor for now
|
||||
CAFFE_ENFORCE(
|
||||
blob->IsTensorType(CPU),
|
||||
BlobIsTensorType(*blob, CPU),
|
||||
"Initialization blob ",
|
||||
name,
|
||||
" needs to be TensorCPU");
|
||||
|
@ -65,8 +65,8 @@ class GPUFallbackOpEx final : public Operator<CUDAContext> {
|
||||
bool need_sync = false;
|
||||
for (int i = 0; i < InputSize(); ++i) {
|
||||
if (this->InputIsTensorType(i, CUDA)) {
|
||||
local_input_blobs_[i]->GetMutableTensor(CPU)->CopyFrom(
|
||||
Input(i), &context_);
|
||||
BlobGetMutableTensor(local_input_blobs_[i], CPU)
|
||||
->CopyFrom(Input(i), &context_);
|
||||
need_sync = true;
|
||||
} else {
|
||||
VLOG(1) << "Input " << i << " is not TensorCUDA. Skipping copy.";
|
||||
@ -95,7 +95,7 @@ class GPUFallbackOpEx final : public Operator<CUDAContext> {
|
||||
continue;
|
||||
}
|
||||
CAFFE_ENFORCE(
|
||||
local_output_blobs_[i]->IsTensorType(CPU),
|
||||
BlobIsTensorType(*local_output_blobs_[i], CPU),
|
||||
"GPU fallback op currently does not support non-TensorCPU "
|
||||
"output type who needs copying.");
|
||||
Output(i)->CopyFrom(local_output_blobs_[i]->template Get<TensorCPU>());
|
||||
|
@ -40,7 +40,7 @@ TEST(OperatorFallbackTest, IncrementByOneOp) {
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
source_tensor.mutable_data<float>()[i] = i;
|
||||
}
|
||||
ws.CreateBlob("X")->GetMutableTensor(CPU)->CopyFrom(source_tensor);
|
||||
BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(source_tensor);
|
||||
unique_ptr<OperatorBase> op(CreateOperator(op_def, &ws));
|
||||
EXPECT_TRUE(op.get() != nullptr);
|
||||
EXPECT_TRUE(op->Run());
|
||||
@ -64,7 +64,7 @@ TEST(OperatorFallbackTest, GPUIncrementByOneOp) {
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
source_tensor.mutable_data<float>()[i] = i;
|
||||
}
|
||||
ws.CreateBlob("X")->GetMutableTensor(CUDA)->CopyFrom(source_tensor);
|
||||
BlobGetMutableTensor(ws.CreateBlob("X"), CUDA)->CopyFrom(source_tensor);
|
||||
unique_ptr<OperatorBase> op(CreateOperator(op_def, &ws));
|
||||
EXPECT_TRUE(op.get() != nullptr);
|
||||
EXPECT_TRUE(op->Run());
|
||||
|
@ -20,7 +20,7 @@ static void AddConstInput(
|
||||
option.set_device_type(PROTO_CUDA);
|
||||
CUDAContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CUDA);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CUDA);
|
||||
tensor->Resize(shape);
|
||||
math::Set<float, CUDAContext>(
|
||||
tensor->size(), value, tensor->template mutable_data<float>(), &context);
|
||||
|
@ -43,11 +43,10 @@ class RecurrentNetworkBlobFetcherOp final : public Operator<Context> {
|
||||
prefix_ + std::string("_") + blob_name + caffe2::to_string(i);
|
||||
blob_names_vector.push_back(newBlobName);
|
||||
|
||||
ws_->CreateBlob(newBlobName)
|
||||
->GetMutableTensor(CPU)
|
||||
BlobGetMutableTensor(ws_->CreateBlob(newBlobName), CPU)
|
||||
->ResizeLike(currentTensor);
|
||||
auto type = Context::GetDeviceType();
|
||||
auto* newTensor = ws_->GetBlob(newBlobName)->GetMutableTensor(type);
|
||||
auto* newTensor = BlobGetMutableTensor(ws_->GetBlob(newBlobName), type);
|
||||
newTensor->CopyFrom(currentTensor);
|
||||
}
|
||||
}
|
||||
|
@ -111,10 +111,10 @@ class RecurrentNetworkExecutorBase {
|
||||
// the forward-only mode.
|
||||
std::string this_timestep_blob =
|
||||
timestep_blob_ + "_rnnexec_t" + caffe2::to_string(t);
|
||||
ws->CreateBlob(this_timestep_blob)->GetMutableTensor(CPU)->Resize(1);
|
||||
BlobGetMutableTensor(ws->CreateBlob(this_timestep_blob), CPU)->Resize(1);
|
||||
auto b = ws->GetBlob(this_timestep_blob);
|
||||
CAFFE_ENFORCE(b);
|
||||
b->GetMutableTensor(CPU)->template mutable_data<int32_t>()[0] = t;
|
||||
BlobGetMutableTensor(b, CPU)->template mutable_data<int32_t>()[0] = t;
|
||||
|
||||
// Copy the operators from template
|
||||
for (auto& template_rnn_op : timestep_ops_template_) {
|
||||
|
@ -52,10 +52,11 @@ struct CAFFE2_API ScratchWorkspaces {
|
||||
};
|
||||
|
||||
inline void UpdateTimestepBlob(Workspace* ws, std::string blob_name, int t) {
|
||||
ws->CreateBlob(blob_name)->GetMutableTensor(CPU)->Resize(1);
|
||||
BlobGetMutableTensor(ws->CreateBlob(blob_name), CPU)->Resize(1);
|
||||
auto timestepBlob = ws->GetBlob(blob_name);
|
||||
CAFFE_ENFORCE(timestepBlob);
|
||||
timestepBlob->GetMutableTensor(CPU)->template mutable_data<int32_t>()[0] = t;
|
||||
BlobGetMutableTensor(timestepBlob, CPU)->template mutable_data<int32_t>()[0] =
|
||||
t;
|
||||
}
|
||||
|
||||
CAFFE2_API std::map<string, string> GetRecurrentMapping(
|
||||
@ -71,8 +72,9 @@ void applyOffsetAlias(
|
||||
<< " at offset: " << oc.offset;
|
||||
auto srcBlob = ws->GetBlob(oc.src);
|
||||
CAFFE_ENFORCE(srcBlob);
|
||||
auto* src = srcBlob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* dst = ws->GetBlob(oc.dst)->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* src = BlobGetMutableTensor(srcBlob, Context::GetDeviceType());
|
||||
auto* dst =
|
||||
BlobGetMutableTensor(ws->GetBlob(oc.dst), Context::GetDeviceType());
|
||||
auto timestep = src->size() / src->dim(0);
|
||||
auto dims = src->dims();
|
||||
const int32_t startDstTimestep =
|
||||
@ -113,7 +115,7 @@ void initializeRecurrentInput(
|
||||
Context* context) {
|
||||
auto stateBlob = ws->GetBlob(rc.state);
|
||||
CAFFE_ENFORCE(stateBlob);
|
||||
auto* state = stateBlob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* state = BlobGetMutableTensor(stateBlob, Context::GetDeviceType());
|
||||
|
||||
auto inputBlob = ws->GetBlob(rc.input);
|
||||
CAFFE_ENFORCE(inputBlob);
|
||||
@ -660,7 +662,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
|
||||
|
||||
auto gBlob = sharedWs_->GetBlob(param.grad);
|
||||
CAFFE_ENFORCE(gBlob);
|
||||
auto* g = gBlob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
|
||||
g->ResizeLike(p);
|
||||
math::Set<T, Context>(
|
||||
g->size(),
|
||||
@ -676,7 +678,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
|
||||
|
||||
auto gBlob = sharedWs_->CreateBlob(rg.grad);
|
||||
CAFFE_ENFORCE(gBlob);
|
||||
auto* g = gBlob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
|
||||
g->ResizeLike(p);
|
||||
CAFFE_ENFORCE_EQ(g->ndim(), 3);
|
||||
const auto timestep = g->size() / g->dim(0);
|
||||
@ -703,7 +705,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
|
||||
<< ". Size: " << Input(gradientInputIndex).size();
|
||||
auto pGradientBlob = sharedWs_->GetBlob(gradientName);
|
||||
CAFFE_ENFORCE(pGradientBlob);
|
||||
auto* g = pGradientBlob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* g = BlobGetMutableTensor(pGradientBlob, Context::GetDeviceType());
|
||||
g->ResizeLike(Input(gradientInputIndex));
|
||||
g->template mutable_data<T>();
|
||||
}
|
||||
@ -717,7 +719,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
|
||||
<< rg.lastExternalGrad << " for final time step (sep. blob)";
|
||||
auto gBlob = sharedWs_->GetBlob(rg.grad);
|
||||
CAFFE_ENFORCE(gBlob);
|
||||
auto* g = gBlob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
|
||||
|
||||
auto oglastBlob = sharedWs_->GetBlob(rg.lastExternalGrad);
|
||||
CAFFE_ENFORCE(oglastBlob);
|
||||
@ -779,7 +781,7 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
|
||||
T* output_data = Output(outputIdx)->template mutable_data<T>();
|
||||
auto pBlob = sharedWs_->GetBlob(recurrentGradients_[i].grad);
|
||||
CAFFE_ENFORCE(pBlob);
|
||||
auto* p = pBlob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* p = BlobGetMutableTensor(pBlob, Context::GetDeviceType());
|
||||
|
||||
if (Input(inputId).ndim() >= 2) {
|
||||
// Gradient states blob should live. And if it gets changed by the
|
||||
|
@ -18,7 +18,7 @@ void AddConstInput(
|
||||
Context* context,
|
||||
Workspace* ws) {
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(Context::GetDeviceType());
|
||||
auto* tensor = BlobGetMutableTensor(blob, Context::GetDeviceType());
|
||||
tensor->Resize(shape);
|
||||
math::Set<float, Context>(
|
||||
tensor->size(), value, tensor->template mutable_data<float>(), context);
|
||||
@ -39,7 +39,7 @@ void AddInput<CPUContext>(
|
||||
const string& name,
|
||||
Workspace* ws) {
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
EigenVectorMap<float> tensor_vec(
|
||||
tensor->template mutable_data<float>(), tensor->size());
|
||||
@ -57,7 +57,7 @@ void AddInput<CUDAContext>(
|
||||
tmp_vec.array() = utils::AsEArrXt(values);
|
||||
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CUDA);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CUDA);
|
||||
tensor->CopyFrom(tmp);
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,7 @@ class StringJoinOpTest : public testing::Test {
|
||||
public:
|
||||
bool runOp(const TensorCPU& input) {
|
||||
auto* blob = ws_.CreateBlob("X");
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->ResizeLike(input);
|
||||
tensor->ShareData(input);
|
||||
|
||||
@ -26,7 +26,7 @@ class StringJoinOpTest : public testing::Test {
|
||||
const std::string* checkAndGetOutput(int outputSize) {
|
||||
const auto* output = ws_.GetBlob("Y");
|
||||
EXPECT_NE(output, nullptr);
|
||||
EXPECT_TRUE(output->IsTensorType(CPU));
|
||||
EXPECT_TRUE(BlobIsTensorType(*output, CPU));
|
||||
const auto& outputTensor = output->Get<TensorCPU>();
|
||||
EXPECT_EQ(outputTensor.ndim(), 1);
|
||||
EXPECT_EQ(outputTensor.dim(0), outputSize);
|
||||
@ -42,7 +42,7 @@ TEST_F(StringJoinOpTest, testString1DJoin) {
|
||||
std::vector<std::string> input = {"a", "xx", "c"};
|
||||
|
||||
auto blob = caffe2::make_unique<Blob>();
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
|
||||
tensor->Resize(input.size());
|
||||
auto* data = tensor->template mutable_data<std::string>();
|
||||
for (int i = 0; i < input.size(); ++i) {
|
||||
@ -62,7 +62,7 @@ TEST_F(StringJoinOpTest, testString2DJoin) {
|
||||
{"dd", "ee", "ff"}};
|
||||
|
||||
auto blob = caffe2::make_unique<Blob>();
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
|
||||
tensor->Resize(input.size(), input[0].size());
|
||||
auto* data = tensor->template mutable_data<std::string>();
|
||||
for (int i = 0; i < input.size(); ++i) {
|
||||
@ -82,7 +82,7 @@ TEST_F(StringJoinOpTest, testFloat1DJoin) {
|
||||
std::vector<float> input = {3.90f, 5.234f, 8.12f};
|
||||
|
||||
auto blob = caffe2::make_unique<Blob>();
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
|
||||
tensor->Resize(input.size());
|
||||
auto* data = tensor->template mutable_data<float>();
|
||||
for (int i = 0; i < input.size(); ++i) {
|
||||
@ -102,7 +102,7 @@ TEST_F(StringJoinOpTest, testFloat2DJoin) {
|
||||
{4.67f, 5.90f, 6.32f}};
|
||||
|
||||
auto blob = caffe2::make_unique<Blob>();
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
|
||||
tensor->Resize(input.size(), input[0].size());
|
||||
auto* data = tensor->template mutable_data<float>();
|
||||
for (int i = 0; i < input.size(); ++i) {
|
||||
@ -122,7 +122,7 @@ TEST_F(StringJoinOpTest, testLong2DJoin) {
|
||||
std::vector<std::vector<int64_t>> input = {{100, 200}, {1000, 2000}};
|
||||
|
||||
auto blob = caffe2::make_unique<Blob>();
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob.get(), CPU);
|
||||
tensor->Resize(input.size(), input[0].size());
|
||||
auto* data = tensor->template mutable_data<int64_t>();
|
||||
for (int i = 0; i < input.size(); ++i) {
|
||||
|
@ -82,10 +82,10 @@ class PackedInt8BGRANHWCToNCHWCStylizerPreprocessOp
|
||||
auto defaultNoiseSize = OperatorBase::GetSingleArgument<int>(
|
||||
"noise_size", 491 /* prime to avoid artifacts */);
|
||||
|
||||
if (!noiseBlob->IsTensorType(CPU)) {
|
||||
if (!BlobIsTensorType(*noiseBlob, CPU)) {
|
||||
// Initialize random noise on first use.
|
||||
// Cache it to maintain temporal consistency.
|
||||
auto* t = noiseBlob->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(noiseBlob, CPU);
|
||||
|
||||
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
|
||||
// Noise space is larger for vectorized code due to the
|
||||
|
@ -56,7 +56,7 @@ bool TensorProtosDBInput<Context>::Prefetch() {
|
||||
protos.mutable_protos(i)->clear_device_detail();
|
||||
}
|
||||
deserializer.Deserialize(
|
||||
protos.protos(i), prefetched_blobs_[i].GetMutableTensor(CPU));
|
||||
protos.protos(i), BlobGetMutableTensor(&prefetched_blobs_[i], CPU));
|
||||
}
|
||||
} else {
|
||||
vector<Tensor> temp_tensors;
|
||||
@ -74,11 +74,11 @@ bool TensorProtosDBInput<Context>::Prefetch() {
|
||||
vector<int> dims(
|
||||
protos.protos(i).dims().begin(), protos.protos(i).dims().end());
|
||||
dims.insert(dims.begin(), batch_size_);
|
||||
prefetched_blobs_[i].GetMutableTensor(CPU)->Resize(dims);
|
||||
BlobGetMutableTensor(&prefetched_blobs_[i], CPU)->Resize(dims);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < protos.protos_size(); ++i) {
|
||||
TensorCPU* dst = prefetched_blobs_[i].GetMutableTensor(CPU);
|
||||
TensorCPU* dst = BlobGetMutableTensor(&prefetched_blobs_[i], CPU);
|
||||
TensorCPU& src = temp_tensors[i];
|
||||
if (protos.protos(i).has_device_detail()) {
|
||||
protos.mutable_protos(i)->clear_device_detail();
|
||||
|
@ -52,7 +52,7 @@ class TTLinearOp final : public Operator<Context> {
|
||||
int cores_idx = 0;
|
||||
|
||||
// Temporary buffer to facilitate multiplication of TT-cores with input
|
||||
auto Y_buf = Y_temp_->GetMutableTensor(Context::GetDeviceType());
|
||||
auto Y_buf = BlobGetMutableTensor(Y_temp_.get(), Context::GetDeviceType());
|
||||
Y_buf->ResizeLike(X);
|
||||
Y_buf->CopyFrom(X);
|
||||
|
||||
|
@ -19,7 +19,7 @@ static void AddConstInput(
|
||||
option.set_device_type(PROTO_CUDA);
|
||||
CUDAContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CUDA);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CUDA);
|
||||
tensor->Resize(shape);
|
||||
math::Set<float, CUDAContext>(
|
||||
tensor->size(), value, tensor->template mutable_data<float>(), &context);
|
||||
|
@ -16,7 +16,7 @@ static void AddConstInput(
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
math::Set<float, CPUContext>(
|
||||
tensor->size(), value, tensor->template mutable_data<float>(), &context);
|
||||
|
@ -44,10 +44,10 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
|
||||
CAFFE_ENFORCE(
|
||||
bnInputs.size() >= 5, "Invalid batch normalization input size");
|
||||
|
||||
#define EXPOSE_TENSOR_DATA(name, index, inputs) \
|
||||
auto name = repr::nn::get<repr::Tensor>(inputs[index]); \
|
||||
assert(ws->HasBlob(name->getName()) && "Blob not in workspace"); \
|
||||
auto name##Tensor = ws->GetBlob(name->getName())->GetMutableTensor(CPU); \
|
||||
#define EXPOSE_TENSOR_DATA(name, index, inputs) \
|
||||
auto name = repr::nn::get<repr::Tensor>(inputs[index]); \
|
||||
assert(ws->HasBlob(name->getName()) && "Blob not in workspace"); \
|
||||
auto name##Tensor = BlobGetMutableTensor(ws->GetBlob(name->getName()), CPU); \
|
||||
auto name##Data = name##Tensor->mutable_data<float>();
|
||||
|
||||
EXPOSE_TENSOR_DATA(filter, 1, convInputs);
|
||||
@ -76,7 +76,7 @@ bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
|
||||
nn->dataFlow.createEdge(convBiasNode, convNode);
|
||||
|
||||
auto* blob = ws->CreateBlob(convBiasName);
|
||||
caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
|
||||
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
|
||||
CHECK_NOTNULL(tensor);
|
||||
// Get output channel
|
||||
size_t c = filterTensor->dim32(0);
|
||||
|
@ -173,7 +173,7 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOp(
|
||||
|
||||
// Feed into workspace as CPU Tensors
|
||||
auto* blob = ws->CreateBlob(t.name());
|
||||
auto* cpu_tensor = blob->GetMutableTensor(CPU);
|
||||
auto* cpu_tensor = BlobGetMutableTensor(blob, CPU);
|
||||
std::vector<int64_t> dims;
|
||||
for(const auto& d : t.dims()) {
|
||||
dims.push_back(d);
|
||||
|
@ -10,14 +10,14 @@ void enforceIsTensor(Workspace* ws, const std::string& name) {
|
||||
auto blob = ws->GetBlob(name);
|
||||
CAFFE_ENFORCE(blob, "Blob does not exist: ", name);
|
||||
CAFFE_ENFORCE(
|
||||
blob->IsTensorType(CPU), "Blob is not a CPU Tensor: ", name);
|
||||
BlobIsTensorType(*blob, CPU), "Blob is not a CPU Tensor: ", name);
|
||||
}
|
||||
|
||||
TensorCPU* getTensor(Workspace* ws, const std::string& name) {
|
||||
enforceIsTensor(ws, name);
|
||||
auto* blob = ws->GetBlob(name);
|
||||
CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
|
||||
return blob->GetMutableTensor(CPU);
|
||||
return BlobGetMutableTensor(blob, CPU);
|
||||
}
|
||||
|
||||
void shareInputTensor(
|
||||
@ -60,7 +60,7 @@ Predictor::Predictor(PredictorConfig config) : config_(std::move(config)) {
|
||||
for (const auto& name : config_.predict_net->external_input()) {
|
||||
if (!initialized.count(name)) {
|
||||
auto* blob = config_.ws->CreateBlob(name);
|
||||
blob->GetMutableTensor(CPU);
|
||||
BlobGetMutableTensor(blob, CPU);
|
||||
}
|
||||
}
|
||||
CAFFE_ENFORCE(config_.ws->CreateNet(config_.predict_net));
|
||||
|
@ -135,7 +135,7 @@ std::unique_ptr<Blob> randomTensor(
|
||||
const std::vector<int64_t>& dims,
|
||||
CPUContext* ctx) {
|
||||
auto blob = make_unique<Blob>();
|
||||
auto* t = blob->GetMutableTensor(CPU);
|
||||
auto* t = BlobGetMutableTensor(blob.get(), CPU);
|
||||
t->Resize(dims);
|
||||
math::RandUniform<float, CPUContext>(
|
||||
t->size(), -1.0, 1.0, t->template mutable_data<float>(), ctx);
|
||||
@ -180,7 +180,7 @@ TEST_F(PredictorTest, SimpleBatchSized) {
|
||||
auto inputData = randomTensor({1, 4}, ctx_.get());
|
||||
Predictor::TensorList input;
|
||||
input.emplace_back(CPU);
|
||||
auto tensor = inputData->GetMutableTensor(CPU);
|
||||
auto tensor = BlobGetMutableTensor(inputData.get(), CPU);
|
||||
input.back().ResizeLike(*tensor);
|
||||
input.back().ShareData(*tensor);
|
||||
Predictor::TensorList output;
|
||||
@ -196,7 +196,7 @@ TEST_F(PredictorTest, SimpleBatchSizedMapInput) {
|
||||
auto inputData = randomTensor({1, 4}, ctx_.get());
|
||||
Predictor::TensorMap input;
|
||||
auto iter = input.emplace("data", Tensor(CPU));
|
||||
auto tensor = inputData->GetMutableTensor(CPU);
|
||||
auto tensor = BlobGetMutableTensor(inputData.get(), CPU);
|
||||
iter.first->second.ResizeLike(*tensor);
|
||||
iter.first->second.ShareData(*tensor);
|
||||
|
||||
|
@ -328,7 +328,7 @@ void addObjectMethods(py::module& m) {
|
||||
})
|
||||
.def(
|
||||
"tensor",
|
||||
[](Blob* blob) { return py::cast(blob->GetMutableTensor(CPU)); },
|
||||
[](Blob* blob) { return py::cast(BlobGetMutableTensor(blob, CPU)); },
|
||||
py::return_value_policy::reference_internal)
|
||||
.def(
|
||||
"_feed",
|
||||
|
@ -234,7 +234,7 @@ class TensorFeeder : public BlobFeederBase {
|
||||
FeedTensor(
|
||||
option,
|
||||
original_array,
|
||||
blob->GetMutableTensor(Context::GetDeviceType()));
|
||||
BlobGetMutableTensor(blob, Context::GetDeviceType()));
|
||||
}
|
||||
};
|
||||
|
||||
@ -366,31 +366,32 @@ class PythonOpBase : public Operator<Context> {
|
||||
|
||||
// make sure output blob is initialized before creating the binding
|
||||
if (forced_cpu_outputs_.count(i)) {
|
||||
blob->GetMutableTensor(Context::GetDeviceType());
|
||||
BlobGetMutableTensor(blob, Context::GetDeviceType());
|
||||
} else {
|
||||
blob->GetMutableTensor(Context::GetDeviceType());
|
||||
BlobGetMutableTensor(blob, Context::GetDeviceType());
|
||||
}
|
||||
|
||||
py::object py_obj;
|
||||
if (blob->template IsType<Tensor>()) {
|
||||
if (use_dlpack) {
|
||||
DLPackWrapper<CPUContext> wrapper(
|
||||
blob->GetMutableTensor(Context::GetDeviceType()), cpu_option);
|
||||
BlobGetMutableTensor(blob, Context::GetDeviceType()),
|
||||
cpu_option);
|
||||
py_obj = py::cast(wrapper, py::return_value_policy::copy);
|
||||
} else {
|
||||
py_obj = py::cast(
|
||||
blob->GetMutableTensor(Context::GetDeviceType()),
|
||||
BlobGetMutableTensor(blob, Context::GetDeviceType()),
|
||||
py::return_value_policy::reference);
|
||||
}
|
||||
} else {
|
||||
if (use_dlpack) {
|
||||
DLPackWrapper<Context> wrapper(
|
||||
blob->GetMutableTensor(Context::GetDeviceType()),
|
||||
BlobGetMutableTensor(blob, Context::GetDeviceType()),
|
||||
this->device_option());
|
||||
py_obj = py::cast(wrapper, py::return_value_policy::copy);
|
||||
} else {
|
||||
py_obj = py::cast(
|
||||
blob->GetMutableTensor(Context::GetDeviceType()),
|
||||
BlobGetMutableTensor(blob, Context::GetDeviceType()),
|
||||
py::return_value_policy::reference);
|
||||
}
|
||||
}
|
||||
|
@ -163,8 +163,8 @@ public:
|
||||
DeviceOption cpu_option(option);
|
||||
cpu_option.set_device_type(DeviceTypeProto::PROTO_CPU);
|
||||
TensorFeeder<CPUContext> cpu_tensor_feeder;
|
||||
cpu_tensor_feeder.FeedTensor(cpu_option, original_array,
|
||||
blob->GetMutableTensor(CPU));
|
||||
cpu_tensor_feeder.FeedTensor(
|
||||
cpu_option, original_array, BlobGetMutableTensor(blob, CPU));
|
||||
}
|
||||
} catch (ideep::error &e) {
|
||||
LOG(ERROR) << "IDEEP error: " << e.message;
|
||||
|
@ -19,7 +19,7 @@ void AddNoiseInput(
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
|
@ -231,11 +231,12 @@ bool NNPACKConvOp::RunOnDeviceWithOrderNCHW() {
|
||||
(transformedFilterSize + sizeof(float) - 1) / sizeof(float);
|
||||
|
||||
for (auto g = 0; g < group_; g++) {
|
||||
transformedFilters_[g] = ws_->CreateBlob(
|
||||
"__transformed_kernel_" +
|
||||
to_string(__sync_fetch_and_add(
|
||||
&precomputed_transform_id, 1)))
|
||||
->GetMutableTensor(CPU);
|
||||
transformedFilters_[g] = BlobGetMutableTensor(
|
||||
ws_->CreateBlob(
|
||||
"__transformed_kernel_" +
|
||||
to_string(
|
||||
__sync_fetch_and_add(&precomputed_transform_id, 1))),
|
||||
CPU);
|
||||
transformedFilters_[g]->Resize(transformedFilterElements);
|
||||
|
||||
status = nnp_convolution_inference(
|
||||
|
@ -19,7 +19,7 @@ void AddNoiseInput(
|
||||
DeviceOption option;
|
||||
CPUContext context(option);
|
||||
Blob* blob = ws->CreateBlob(name);
|
||||
auto* tensor = blob->GetMutableTensor(CPU);
|
||||
auto* tensor = BlobGetMutableTensor(blob, CPU);
|
||||
tensor->Resize(shape);
|
||||
|
||||
math::RandGaussian<float, CPUContext>(
|
||||
|
@ -26,13 +26,13 @@ TEST(MathROCBLASTest, GemmNoTransNoTrans) {
|
||||
vector<int> shapeX{5, 10};
|
||||
vector<int> shapeW{10, 6};
|
||||
vector<int> shapeY{5, 6};
|
||||
auto* tensorX = blobX->GetMutableTensor(HIP);
|
||||
auto* tensorX = BlobGetMutableTensor(blobX, HIP);
|
||||
tensorX->Resize(shapeX);
|
||||
auto* tensorW = blobW->GetMutableTensor(HIP);
|
||||
auto* tensorW = BlobGetMutableTensor(blobW, HIP);
|
||||
tensorW->Resize(shapeW);
|
||||
auto* tensorY = blobY->GetMutableTensor(HIP);
|
||||
auto* tensorY = BlobGetMutableTensor(blobY, HIP);
|
||||
tensorY->Resize(shapeY);
|
||||
auto* tensorY_host = blobY_host->GetMutableTensor(CPU);
|
||||
auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU);
|
||||
tensorY_host->Resize(shapeY);
|
||||
|
||||
EXPECT_EQ(tensorX->size(), 50);
|
||||
@ -126,13 +126,13 @@ TEST(MathROCBLASTest, GemmNoTransTrans) {
|
||||
vector<int> shapeX{5, 10};
|
||||
vector<int> shapeW{6, 10};
|
||||
vector<int> shapeY{5, 6};
|
||||
auto* tensorX = blobX->GetMutableTensor(HIP);
|
||||
auto* tensorX = BlobGetMutableTensor(blobX, HIP);
|
||||
tensorX->Resize(shapeX);
|
||||
auto* tensorW = blobW->GetMutableTensor(HIP);
|
||||
auto* tensorW = BlobGetMutableTensor(blobW, HIP);
|
||||
tensorW->Resize(shapeW);
|
||||
auto* tensorY = blobY->GetMutableTensor(HIP);
|
||||
auto* tensorY = BlobGetMutableTensor(blobY, HIP);
|
||||
tensorY->Resize(shapeY);
|
||||
auto* tensorY_host = blobY_host->GetMutableTensor(CPU);
|
||||
auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU);
|
||||
tensorY_host->Resize(shapeY);
|
||||
|
||||
EXPECT_EQ(tensorX->size(), 50);
|
||||
@ -225,13 +225,13 @@ TEST(MathROCBLASTest, GemvNoTrans) {
|
||||
vector<int> shapeA{5, 10};
|
||||
vector<int> shapeX{10};
|
||||
vector<int> shapeY{5};
|
||||
auto* tensorA = blobA->GetMutableTensor(HIP);
|
||||
auto* tensorA = BlobGetMutableTensor(blobA, HIP);
|
||||
tensorA->Resize(shapeA);
|
||||
auto* tensorX = blobX->GetMutableTensor(HIP);
|
||||
auto* tensorX = BlobGetMutableTensor(blobX, HIP);
|
||||
tensorX->Resize(shapeX);
|
||||
auto* tensorY = blobY->GetMutableTensor(HIP);
|
||||
auto* tensorY = BlobGetMutableTensor(blobY, HIP);
|
||||
tensorY->Resize(shapeY);
|
||||
auto* tensorY_host = blobY_host->GetMutableTensor(CPU);
|
||||
auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU);
|
||||
tensorY_host->Resize(shapeY);
|
||||
|
||||
EXPECT_EQ(tensorA->size(), 50);
|
||||
@ -315,13 +315,13 @@ TEST(MathROCBLASTest, GemvTrans) {
|
||||
vector<int> shapeA{6, 10};
|
||||
vector<int> shapeX{6};
|
||||
vector<int> shapeY{10};
|
||||
auto* tensorA = blobA->GetMutableTensor(HIP);
|
||||
auto* tensorA = BlobGetMutableTensor(blobA, HIP);
|
||||
tensorA->Resize(shapeA);
|
||||
auto* tensorX = blobX->GetMutableTensor(HIP);
|
||||
auto* tensorX = BlobGetMutableTensor(blobX, HIP);
|
||||
tensorX->Resize(shapeX);
|
||||
auto* tensorY = blobY->GetMutableTensor(HIP);
|
||||
auto* tensorY = BlobGetMutableTensor(blobY, HIP);
|
||||
tensorY->Resize(shapeY);
|
||||
auto* tensorY_host = blobY_host->GetMutableTensor(CPU);
|
||||
auto* tensorY_host = BlobGetMutableTensor(blobY_host, CPU);
|
||||
tensorY_host->Resize(shapeY);
|
||||
|
||||
EXPECT_EQ(tensorA->size(), 60);
|
||||
|
@ -41,9 +41,9 @@ void executeGpuBinaryOpTest(
|
||||
Blob* bloby = ws.CreateBlob("Y");
|
||||
Blob* bloby_host = ws.CreateBlob("Y_host");
|
||||
|
||||
auto* tensorx0 = blobx0->GetMutableTensor(CUDA);
|
||||
auto* tensorx1 = blobx1->GetMutableTensor(CUDA);
|
||||
auto* tensory = bloby->GetMutableTensor(CUDA);
|
||||
auto* tensorx0 = BlobGetMutableTensor(blobx0, CUDA);
|
||||
auto* tensorx1 = BlobGetMutableTensor(blobx1, CUDA);
|
||||
auto* tensory = BlobGetMutableTensor(bloby, CUDA);
|
||||
|
||||
vector<int> shapex0_vector{shapex0};
|
||||
vector<int> shapex1_vector{shapex1};
|
||||
@ -71,7 +71,7 @@ void executeGpuBinaryOpTest(
|
||||
context.FinishDeviceComputation();
|
||||
|
||||
// Copy result to CPU so we can inspect it
|
||||
auto* tensory_host = bloby_host->GetMutableTensor(CPU);
|
||||
auto* tensory_host = BlobGetMutableTensor(bloby_host, CPU);
|
||||
tensory_host->CopyFrom(*tensory, &context);
|
||||
context.FinishDeviceComputation();
|
||||
|
||||
@ -94,7 +94,7 @@ TEST(MathUtilGPUTest, testAddStripedBatch) {
|
||||
vector<int> shapex{33 * 9, 25};
|
||||
vector<int> shapey{33, 25};
|
||||
|
||||
auto* tensorx = blobx->GetMutableTensor(CUDA);
|
||||
auto* tensorx = BlobGetMutableTensor(blobx, CUDA);
|
||||
tensorx->Resize(shapex);
|
||||
int stripe = 33 * 25;
|
||||
vector<float> tot(33, 0.0);
|
||||
@ -110,7 +110,7 @@ TEST(MathUtilGPUTest, testAddStripedBatch) {
|
||||
}
|
||||
}
|
||||
|
||||
auto* tensory = bloby->GetMutableTensor(CUDA);
|
||||
auto* tensory = BlobGetMutableTensor(bloby, CUDA);
|
||||
tensory->Resize(shapey);
|
||||
math::Set<float, CUDAContext>(
|
||||
stripe, 0.0, tensory->mutable_data<float>(), &context);
|
||||
@ -125,7 +125,7 @@ TEST(MathUtilGPUTest, testAddStripedBatch) {
|
||||
context.FinishDeviceComputation();
|
||||
|
||||
// Copy result to CPU so we can inspect it
|
||||
auto* tensory_host = bloby_host->GetMutableTensor(CPU);
|
||||
auto* tensory_host = BlobGetMutableTensor(bloby_host, CPU);
|
||||
tensory_host->CopyFrom(*tensory, &context);
|
||||
context.FinishDeviceComputation();
|
||||
|
||||
@ -258,9 +258,9 @@ class GemmBatchedGPUTest
|
||||
Blob* X_blob = ws_.CreateBlob("X");
|
||||
Blob* W_blob = ws_.CreateBlob("W");
|
||||
Blob* Y_blob = ws_.CreateBlob("Y");
|
||||
X_ = X_blob->GetMutableTensor(CUDA);
|
||||
W_ = W_blob->GetMutableTensor(CUDA);
|
||||
Y_ = Y_blob->GetMutableTensor(CUDA);
|
||||
X_ = BlobGetMutableTensor(X_blob, CUDA);
|
||||
W_ = BlobGetMutableTensor(W_blob, CUDA);
|
||||
Y_ = BlobGetMutableTensor(Y_blob, CUDA);
|
||||
X_->Resize(std::vector<int64_t>{3, 5, 10});
|
||||
W_->Resize(std::vector<int64_t>{3, 6, 10});
|
||||
Y_->Resize(std::vector<int64_t>{3, 5, 6});
|
||||
@ -381,8 +381,8 @@ class ReduceTensorGPUTest : public testing::Test {
|
||||
cuda_context_ = make_unique<CUDAContext>(option_);
|
||||
Blob* blob_x = ws_.CreateBlob("X");
|
||||
Blob* blob_y = ws_.CreateBlob("Y");
|
||||
X_ = blob_x->GetMutableTensor(CUDA);
|
||||
Y_ = blob_y->GetMutableTensor(CUDA);
|
||||
X_ = BlobGetMutableTensor(blob_x, CUDA);
|
||||
Y_ = BlobGetMutableTensor(blob_y, CUDA);
|
||||
}
|
||||
|
||||
void SetUpData(
|
||||
@ -402,7 +402,7 @@ class ReduceTensorGPUTest : public testing::Test {
|
||||
|
||||
void VerifyResult(const std::vector<float>& expected_output) {
|
||||
Blob* blob_y_host = ws_.CreateBlob("Y_host");
|
||||
auto* Y_host = blob_y_host->GetMutableTensor(CPU);
|
||||
auto* Y_host = BlobGetMutableTensor(blob_y_host, CPU);
|
||||
Y_host->CopyFrom(*Y_, cuda_context_.get());
|
||||
cuda_context_->FinishDeviceComputation();
|
||||
ASSERT_EQ(expected_output.size(), Y_host->size());
|
||||
@ -664,8 +664,8 @@ class BroadcastGPUTest : public testing::Test {
|
||||
cuda_context_ = make_unique<CUDAContext>(option_);
|
||||
Blob* blob_x = ws_.CreateBlob("X");
|
||||
Blob* blob_y = ws_.CreateBlob("Y");
|
||||
X_ = blob_x->GetMutableTensor(CUDA);
|
||||
Y_ = blob_y->GetMutableTensor(CUDA);
|
||||
X_ = BlobGetMutableTensor(blob_x, CUDA);
|
||||
Y_ = BlobGetMutableTensor(blob_y, CUDA);
|
||||
}
|
||||
|
||||
void SetUpData(
|
||||
@ -681,7 +681,7 @@ class BroadcastGPUTest : public testing::Test {
|
||||
|
||||
void VerifyResult(const std::vector<float>& expected_output) {
|
||||
Blob* blob_y_host = ws_.CreateBlob("Y_host");
|
||||
auto* Y_host = blob_y_host->GetMutableTensor(CPU);
|
||||
auto* Y_host = BlobGetMutableTensor(blob_y_host, CPU);
|
||||
Y_host->CopyFrom(*Y_, cuda_context_.get());
|
||||
cuda_context_->FinishDeviceComputation();
|
||||
ASSERT_EQ(expected_output.size(), Y_host->size());
|
||||
@ -741,9 +741,9 @@ class MomentsGPUTest : public testing::Test {
|
||||
Blob* blob_x = ws_.CreateBlob("X");
|
||||
Blob* blob_mean = ws_.CreateBlob("mean");
|
||||
Blob* blob_variance = ws_.CreateBlob("variance");
|
||||
X_ = blob_x->GetMutableTensor(CUDA);
|
||||
mean_ = blob_mean->GetMutableTensor(CUDA);
|
||||
variance_ = blob_variance->GetMutableTensor(CUDA);
|
||||
X_ = BlobGetMutableTensor(blob_x, CUDA);
|
||||
mean_ = BlobGetMutableTensor(blob_mean, CUDA);
|
||||
variance_ = BlobGetMutableTensor(blob_variance, CUDA);
|
||||
}
|
||||
|
||||
void SetUpData(
|
||||
@ -766,10 +766,10 @@ class MomentsGPUTest : public testing::Test {
|
||||
const std::vector<float>& mean_data,
|
||||
const std::vector<float>& variance_data) {
|
||||
Blob* blob_mean_host = ws_.CreateBlob("mean_host");
|
||||
auto* mean_host = blob_mean_host->GetMutableTensor(CPU);
|
||||
auto* mean_host = BlobGetMutableTensor(blob_mean_host, CPU);
|
||||
mean_host->CopyFrom(*mean_, cuda_context_.get());
|
||||
Blob* blob_variance_host = ws_.CreateBlob("variance_host");
|
||||
auto* variance_host = blob_variance_host->GetMutableTensor(CPU);
|
||||
auto* variance_host = BlobGetMutableTensor(blob_variance_host, CPU);
|
||||
variance_host->CopyFrom(*variance_, cuda_context_.get());
|
||||
cuda_context_->FinishDeviceComputation();
|
||||
|
||||
@ -868,8 +868,8 @@ class TransposeGPUTest : public testing::Test {
|
||||
cuda_context_ = make_unique<CUDAContext>(option_);
|
||||
Blob* blob_x = ws_.CreateBlob("X");
|
||||
Blob* blob_y = ws_.CreateBlob("Y");
|
||||
X_ = blob_x->GetMutableTensor(CUDA);
|
||||
Y_ = blob_y->GetMutableTensor(CUDA);
|
||||
X_ = BlobGetMutableTensor(blob_x, CUDA);
|
||||
Y_ = BlobGetMutableTensor(blob_y, CUDA);
|
||||
}
|
||||
|
||||
void SetUpData(
|
||||
@ -890,7 +890,7 @@ class TransposeGPUTest : public testing::Test {
|
||||
|
||||
void VerifyResult(const std::vector<float>& expected_output) {
|
||||
Blob* blob_y_host = ws_.CreateBlob("Y_host");
|
||||
auto* Y_host = blob_y_host->GetMutableTensor(CPU);
|
||||
auto* Y_host = BlobGetMutableTensor(blob_y_host, CPU);
|
||||
Y_host->CopyFrom(*Y_, cuda_context_.get());
|
||||
cuda_context_->FinishDeviceComputation();
|
||||
ASSERT_EQ(expected_output.size(), Y_host->size());
|
||||
|
Reference in New Issue
Block a user