mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-04 08:00:58 +08:00
Compare commits
1 Commits
whc/stage2
...
mlazos/bas
| Author | SHA1 | Date | |
|---|---|---|---|
| a9cb7a187b |
@ -149,21 +149,6 @@ case "$image" in
|
||||
TRITON=yes
|
||||
INDUCTOR_BENCHMARKS=yes
|
||||
;;
|
||||
pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks)
|
||||
CUDA_VERSION=12.1.1
|
||||
CUDNN_VERSION=8
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
GCC_VERSION=9
|
||||
PROTOBUF=yes
|
||||
DB=yes
|
||||
VISION=yes
|
||||
KATEX=yes
|
||||
UCX_COMMIT=${_UCX_COMMIT}
|
||||
UCC_COMMIT=${_UCC_COMMIT}
|
||||
CONDA_CMAKE=yes
|
||||
TRITON=yes
|
||||
INDUCTOR_BENCHMARKS=yes
|
||||
;;
|
||||
pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9)
|
||||
CUDA_VERSION=11.8.0
|
||||
CUDNN_VERSION=8
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
set -euo pipefail
|
||||
|
||||
readonly version=v24.04
|
||||
readonly version=v23.08
|
||||
readonly src_host=https://review.mlplatform.org/ml
|
||||
readonly src_repo=ComputeLibrary
|
||||
|
||||
|
||||
@ -1278,10 +1278,6 @@ elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHAR
|
||||
elif [[ "${TEST_CONFIG}" == *dynamo* && $SHARD_NUMBER -gt 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
|
||||
install_torchvision
|
||||
test_dynamo_shard "${SHARD_NUMBER}"
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *rocm* && -n "$TESTS_TO_INCLUDE" ]]; then
|
||||
install_torchvision
|
||||
test_python_shard "$SHARD_NUMBER"
|
||||
test_aten
|
||||
elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
|
||||
test_without_numpy
|
||||
install_torchvision
|
||||
@ -1311,6 +1307,10 @@ elif [[ "${BUILD_ENVIRONMENT}" == *-mobile-lightweight-dispatch* ]]; then
|
||||
test_libtorch
|
||||
elif [[ "${TEST_CONFIG}" = docs_test ]]; then
|
||||
test_docs_test
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *rocm* && -n "$TESTS_TO_INCLUDE" ]]; then
|
||||
install_torchvision
|
||||
test_python
|
||||
test_aten
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *xpu* ]]; then
|
||||
install_torchvision
|
||||
test_python
|
||||
|
||||
1
.github/workflows/docker-builds.yml
vendored
1
.github/workflows/docker-builds.yml
vendored
@ -42,7 +42,6 @@ jobs:
|
||||
pytorch-linux-focal-cuda12.4-cudnn8-py3-gcc9-inductor-benchmarks,
|
||||
pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9,
|
||||
pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks,
|
||||
pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks,
|
||||
pytorch-linux-focal-cuda11.8-cudnn8-py3-gcc9,
|
||||
pytorch-linux-focal-py3.8-clang10,
|
||||
pytorch-linux-focal-py3.11-clang10,
|
||||
|
||||
23
.github/workflows/inductor.yml
vendored
23
.github/workflows/inductor.yml
vendored
@ -107,27 +107,6 @@ jobs:
|
||||
secrets:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
|
||||
linux-focal-cuda12_1-py3_12-gcc9-inductor-build:
|
||||
name: cuda12.1-py3.12-gcc9-sm86
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86
|
||||
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3.12-gcc9-inductor-benchmarks
|
||||
cuda-arch-list: '8.6'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
]}
|
||||
|
||||
linux-focal-cuda12_1-py3_12-gcc9-inductor-test:
|
||||
name: cuda12.1-py3.12-gcc9-sm86
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: linux-focal-cuda12_1-py3_12-gcc9-inductor-build
|
||||
with:
|
||||
build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86
|
||||
docker-image: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-focal-cuda12_1-py3_12-gcc9-inductor-build.outputs.test-matrix }}
|
||||
|
||||
linux-jammy-cpu-py3_8-gcc11-inductor-build:
|
||||
name: linux-jammy-cpu-py3.8-gcc11-inductor
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
@ -146,7 +125,7 @@ jobs:
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" },
|
||||
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.24xl.spr-metal" },
|
||||
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
|
||||
]}
|
||||
secrets:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
|
||||
6
.github/workflows/trunk.yml
vendored
6
.github/workflows/trunk.yml
vendored
@ -192,9 +192,7 @@ jobs:
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" },
|
||||
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" },
|
||||
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" },
|
||||
{ config: "default", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" },
|
||||
]}
|
||||
|
||||
linux-focal-rocm6_1-py3_8-test:
|
||||
@ -210,4 +208,4 @@ jobs:
|
||||
build-environment: linux-focal-rocm6.1-py3.8
|
||||
docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }}
|
||||
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
|
||||
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
|
||||
|
||||
@ -1,43 +0,0 @@
|
||||
name: Upload test stats intermediate
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
workflow_id:
|
||||
description: workflow_id of the run
|
||||
required: true
|
||||
workflow_run_attempt:
|
||||
description: workflow_run_attempt of the run
|
||||
required: true
|
||||
|
||||
jobs:
|
||||
intermediate_upload_test_stats:
|
||||
name: Intermediate upload test stats for ${{ inputs.workflow_id }}
|
||||
runs-on: ubuntu-22.04
|
||||
environment: upload-stats
|
||||
steps:
|
||||
- name: Checkout PyTorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: false
|
||||
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: pip
|
||||
|
||||
- run: |
|
||||
pip3 install requests==2.26 rockset==1.0.3 boto3==1.19.12
|
||||
|
||||
- name: Upload test stats
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
WORKFLOW_RUN_ID: ${{ inputs.workflow_id }}
|
||||
WORKFLOW_RUN_ATTEMPT: ${{ inputs.workflow_run_attempt }}
|
||||
run: |
|
||||
python3 -m tools.stats.upload_test_stats_intermediate \
|
||||
--workflow-run-id "${WORKFLOW_RUN_ID}" \
|
||||
--workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" \
|
||||
12
.gitmodules
vendored
12
.gitmodules
vendored
@ -2,6 +2,10 @@
|
||||
ignore = dirty
|
||||
path = third_party/pybind11
|
||||
url = https://github.com/pybind/pybind11.git
|
||||
[submodule "third_party/cub"]
|
||||
ignore = dirty
|
||||
path = third_party/cub
|
||||
url = https://github.com/NVlabs/cub.git
|
||||
[submodule "third_party/eigen"]
|
||||
ignore = dirty
|
||||
path = third_party/eigen
|
||||
@ -46,6 +50,10 @@
|
||||
ignore = dirty
|
||||
path = third_party/psimd
|
||||
url = https://github.com/Maratyszcza/psimd.git
|
||||
[submodule "third_party/zstd"]
|
||||
ignore = dirty
|
||||
path = third_party/zstd
|
||||
url = https://github.com/facebook/zstd.git
|
||||
[submodule "third_party/cpuinfo"]
|
||||
ignore = dirty
|
||||
path = third_party/cpuinfo
|
||||
@ -144,7 +152,3 @@
|
||||
[submodule "third_party/opentelemetry-cpp"]
|
||||
path = third_party/opentelemetry-cpp
|
||||
url = https://github.com/open-telemetry/opentelemetry-cpp.git
|
||||
[submodule "third_party/cpp-httplib"]
|
||||
path = third_party/cpp-httplib
|
||||
url = https://github.com/yhirose/cpp-httplib.git
|
||||
branch = v0.15.3
|
||||
|
||||
@ -1052,12 +1052,6 @@ exclude_patterns = [
|
||||
'test/quantization/fx/test_numeric_suite_fx.py',
|
||||
'test/quantization/fx/test_quantize_fx.py',
|
||||
'test/quantization/fx/test_subgraph_rewriter.py',
|
||||
'test/test_fake_tensor.py',
|
||||
'test/test_flop_counter.py',
|
||||
'test/test_function_schema.py',
|
||||
'test/test_functional_autograd_benchmark.py',
|
||||
'test/test_functional_optim.py',
|
||||
'test/test_functionalization_of_rng_ops.py',
|
||||
'test/test_datapipe.py',
|
||||
'test/test_futures.py',
|
||||
'test/test_fx.py',
|
||||
@ -1143,6 +1137,7 @@ exclude_patterns = [
|
||||
'test/test_transformers.py',
|
||||
'test/test_type_promotion.py',
|
||||
'test/test_unary_ufuncs.py',
|
||||
'test/test_utils.py',
|
||||
'test/test_vulkan.py',
|
||||
'test/test_xnnpack_integration.py',
|
||||
'test/torch_np/numpy_test/**/*.py',
|
||||
@ -1929,6 +1924,8 @@ exclude_patterns = [
|
||||
'torch/utils/_mode_utils.py',
|
||||
'torch/utils/_python_dispatch.py',
|
||||
'torch/utils/_stats.py',
|
||||
'torch/utils/_sympy/__init__.py',
|
||||
'torch/utils/_sympy/functions.py',
|
||||
'torch/utils/_traceback.py',
|
||||
'torch/utils/_zip.py',
|
||||
'torch/utils/backcompat/__init__.py',
|
||||
|
||||
@ -663,7 +663,6 @@ cu_library(
|
||||
name = "torch_cuda",
|
||||
srcs = [
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
],
|
||||
copts = torch_cuda_half_options,
|
||||
@ -772,7 +771,7 @@ cc_library(
|
||||
[
|
||||
"torch/*.h",
|
||||
"torch/csrc/**/*.h",
|
||||
"torch/csrc/distributed/c10d/**/*.hpp",
|
||||
"torch/csrc/distributed/c10d/*.hpp",
|
||||
"torch/lib/libshm/*.h",
|
||||
],
|
||||
exclude = [
|
||||
@ -831,7 +830,6 @@ cc_library(
|
||||
"torch/csrc/cuda/python_nccl.cpp",
|
||||
"torch/csrc/cuda/nccl.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
],
|
||||
)) + torch_sources,
|
||||
|
||||
@ -279,13 +279,11 @@ endif()
|
||||
option(USE_SLEEF_FOR_ARM_VEC256 "Use sleef for arm" OFF)
|
||||
option(USE_SOURCE_DEBUG_ON_MOBILE "Enable" ON)
|
||||
option(USE_LITE_INTERPRETER_PROFILER "Enable" ON)
|
||||
cmake_dependent_option(
|
||||
USE_LITE_AOTI "Include AOTI sources" OFF
|
||||
"BUILD_LITE_INTERPRETER" OFF)
|
||||
option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
|
||||
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
|
||||
# option USE_XNNPACK: try to enable xnnpack by default.
|
||||
option(USE_XNNPACK "Use XNNPACK" ON)
|
||||
option(USE_ZSTD "Use ZSTD" OFF)
|
||||
option(USE_ROCM_KERNEL_ASSERT "Use Kernel Assert for ROCm" OFF)
|
||||
# Ensure that an ITT build is the default for x86 CPUs
|
||||
cmake_dependent_option(
|
||||
@ -415,6 +413,7 @@ option(USE_SYSTEM_FXDIV "Use system-provided fxdiv." OFF)
|
||||
option(USE_SYSTEM_BENCHMARK "Use system-provided google benchmark." OFF)
|
||||
option(USE_SYSTEM_ONNX "Use system-provided onnx." OFF)
|
||||
option(USE_SYSTEM_XNNPACK "Use system-provided xnnpack." OFF)
|
||||
option(USE_SYSTEM_ZSTD "Use system-provided zstd." OFF)
|
||||
option(USE_GOLD_LINKER "Use ld.gold to link" OFF)
|
||||
if(USE_SYSTEM_LIBS)
|
||||
set(USE_SYSTEM_CPUINFO ON)
|
||||
@ -436,6 +435,9 @@ if(USE_SYSTEM_LIBS)
|
||||
if(USE_TBB)
|
||||
set(USE_SYSTEM_TBB ON)
|
||||
endif()
|
||||
if(USE_ZSTD)
|
||||
set(USE_SYSTEM_ZSTD ON)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Used when building Caffe2 through setup.py
|
||||
|
||||
@ -887,12 +887,12 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::create(
|
||||
}
|
||||
|
||||
IValue IValue::deepcopy(std::optional<at::Device> device) const {
|
||||
IValue::HashIdentityIValueMap memo;
|
||||
IValue::HashAliasedIValueMap memo;
|
||||
return deepcopy(memo, device);
|
||||
}
|
||||
|
||||
IValue IValue::deepcopy(
|
||||
IValue::HashIdentityIValueMap& memo,
|
||||
IValue::HashAliasedIValueMap& memo,
|
||||
std::optional<at::Device> device) const {
|
||||
if (memo.count(*this)) {
|
||||
return memo.at(*this);
|
||||
@ -1028,12 +1028,12 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy_to_weak_compilation_ref(
|
||||
|
||||
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
|
||||
std::optional<at::Device> device) const {
|
||||
IValue::HashIdentityIValueMap memo;
|
||||
IValue::HashAliasedIValueMap memo;
|
||||
return deepcopy(memo, device);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
|
||||
IValue::HashIdentityIValueMap& memo,
|
||||
IValue::HashAliasedIValueMap& memo,
|
||||
std::optional<at::Device> device) const {
|
||||
auto cu = type_.cu_;
|
||||
auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());
|
||||
|
||||
@ -1117,23 +1117,6 @@ struct TORCH_API IValue final {
|
||||
using HashAliasedIValueMap =
|
||||
std::unordered_map<IValue, IValue, HashAliasedIValue, CompAliasedIValues>;
|
||||
|
||||
struct HashIdentityIValue {
|
||||
size_t operator()(const IValue& val) const {
|
||||
return val.payload.u.as_int;
|
||||
}
|
||||
};
|
||||
|
||||
struct CompIdentityIValues {
|
||||
bool operator()(const IValue& lhs, const IValue& rhs) const {
|
||||
return lhs.is(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
using HashIdentityIValues =
|
||||
std::unordered_set<IValue, HashIdentityIValue, CompIdentityIValues>;
|
||||
using HashIdentityIValueMap =
|
||||
std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;
|
||||
|
||||
// Chechs if this and rhs has a subvalues in common.
|
||||
// [t1,t2] and [t2, t3] returns true.
|
||||
bool overlaps(const IValue& rhs) const;
|
||||
@ -1147,7 +1130,7 @@ struct TORCH_API IValue final {
|
||||
void visit(const std::function<bool(const IValue&)>& visitor) const;
|
||||
IValue deepcopy(std::optional<at::Device> device = c10::nullopt) const;
|
||||
IValue deepcopy(
|
||||
HashIdentityIValueMap& memo,
|
||||
HashAliasedIValueMap& memo,
|
||||
std::optional<at::Device> device = c10::nullopt) const;
|
||||
|
||||
private:
|
||||
|
||||
@ -1589,7 +1589,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
||||
std::optional<at::Device> device = c10::nullopt) const;
|
||||
|
||||
c10::intrusive_ptr<Object> deepcopy(
|
||||
IValue::HashIdentityIValueMap& memo,
|
||||
IValue::HashAliasedIValueMap& memo,
|
||||
std::optional<at::Device> device = c10::nullopt) const;
|
||||
|
||||
bool is_weak_compilation_ref() const {
|
||||
|
||||
@ -1422,13 +1422,10 @@ void scaled_gemm(
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
|
||||
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60200)
|
||||
// Amax support in ROCm as of 6.2
|
||||
if (isFloat8Type(result_dtype)) {
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr);
|
||||
}
|
||||
#endif
|
||||
#ifndef USE_ROCM
|
||||
if (isFloat8Type(result_dtype)) {
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr);
|
||||
}
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
|
||||
#endif
|
||||
CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't');
|
||||
|
||||
@ -215,87 +215,6 @@ static inline float16_t reduce(float16x8_t x) {
|
||||
return reduce(vadd_f16(vget_low_f16(x), vget_high_f16(x)));
|
||||
}
|
||||
|
||||
/*
|
||||
* The below reduce overload and
|
||||
* fp16_gemv_trans_fp16_arith_by_dot_products function is adapted from
|
||||
* llama.cpp's ggml_vec_dot_f16 and surrounding utility functions, so
|
||||
* here is the required copyright notice:
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2023-2024 The ggml authors
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*/
|
||||
#define F16_ELEMENTS_PER_ITERATION 32
|
||||
#define F16_ELEMENTS_PER_REGISTER 8
|
||||
#define F16_REGISTERS_PER_ITERATION (F16_ELEMENTS_PER_ITERATION / F16_ELEMENTS_PER_REGISTER)
|
||||
static inline double reduce(float16x8_t x[F16_REGISTERS_PER_ITERATION]) {
|
||||
int offset = F16_REGISTERS_PER_ITERATION / 2;
|
||||
for (int i = 0; i < offset; ++i) {
|
||||
x[i] = vaddq_f16(x[i], x[offset + i]);
|
||||
}
|
||||
offset /= 2;
|
||||
for (int i = 0; i < offset; ++i) {
|
||||
x[i] = vaddq_f16(x[i], x[offset + i]);
|
||||
}
|
||||
offset /= 2;
|
||||
for (int i = 0; i < offset; ++i) {
|
||||
x[i] = vaddq_f16(x[i], x[offset + i]);
|
||||
}
|
||||
const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0]));
|
||||
const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0]));
|
||||
return (double)vaddvq_f32(vaddq_f32(t0, t1));
|
||||
|
||||
}
|
||||
|
||||
static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) {
|
||||
#ifdef __ARM_FEATURE_FMA
|
||||
return vfmaq_f16(a, b, c);
|
||||
#else
|
||||
return vaddq_f16(a, vmulq_f16(b, c));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Rather than unrolling to process multiple rows (transposed columns)
|
||||
// of matrix A at once as done in fp16_gemv_trans_fp16_arith, unroll
|
||||
// along an individual dot product.
|
||||
static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
|
||||
parallel_for(0, n, 1, [&](int begin, int end) {
|
||||
for (int i = begin; i < end; ++i) {
|
||||
float16x8_t sum[F16_REGISTERS_PER_ITERATION] = {vdupq_n_f16(0)};
|
||||
float16x8_t ax[F16_REGISTERS_PER_ITERATION];
|
||||
float16x8_t ay[F16_REGISTERS_PER_ITERATION];
|
||||
|
||||
for (int j = 0; j < m; j += F16_ELEMENTS_PER_ITERATION) {
|
||||
for (int k = 0; k < F16_REGISTERS_PER_ITERATION; ++k) {
|
||||
ax[k] = vld1q_f16(x + j + k * F16_ELEMENTS_PER_REGISTER);
|
||||
ay[k] = vld1q_f16(a + lda * i + j + k * F16_ELEMENTS_PER_REGISTER);
|
||||
sum[k] = f16_fma(sum[k], ax[k], ay[k]);
|
||||
}
|
||||
}
|
||||
// TODO: add a tail fixup so we don't have to have such a
|
||||
// restrictive gate to enter this path.
|
||||
y[i * incy] = reduce(sum);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static void fp16_gemv_trans_fp16_arith(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
|
||||
parallel_for(0, n / 4, 1, [&](int begin, int end) {
|
||||
@ -311,13 +230,13 @@ static void fp16_gemv_trans_fp16_arith(const int m, const int n, const float16_t
|
||||
for (auto j = 0; j < m; j += 8) {
|
||||
float16x8_t xVec = vld1q_f16(x + j);
|
||||
float16x8_t a0Vec = vld1q_f16(row0 + j);
|
||||
sum0Vec = f16_fma(sum0Vec, a0Vec, xVec);
|
||||
sum0Vec = vaddq_f16(sum0Vec, vmulq_f16(a0Vec, xVec));
|
||||
float16x8_t a1Vec = vld1q_f16(row1 + j);
|
||||
sum1Vec = f16_fma(sum1Vec, a1Vec, xVec);
|
||||
sum1Vec = vaddq_f16(sum1Vec, vmulq_f16(a1Vec, xVec));
|
||||
float16x8_t a2Vec = vld1q_f16(row2 + j);
|
||||
sum2Vec = f16_fma(sum2Vec, a2Vec, xVec);
|
||||
sum2Vec = vaddq_f16(sum2Vec, vmulq_f16(a2Vec, xVec));
|
||||
float16x8_t a3Vec = vld1q_f16(row3 + j);
|
||||
sum3Vec = f16_fma(sum3Vec, a3Vec, xVec);
|
||||
sum3Vec = vaddq_f16(sum3Vec, vmulq_f16(a3Vec, xVec));
|
||||
}
|
||||
y[(i + 0) * incy] = reduce(sum0Vec);
|
||||
y[(i + 1) * incy] = reduce(sum1Vec);
|
||||
@ -326,7 +245,6 @@ static void fp16_gemv_trans_fp16_arith(const int m, const int n, const float16_t
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
static inline float reduce(float32x4_t x) {
|
||||
@ -334,14 +252,6 @@ static inline float reduce(float32x4_t x) {
|
||||
return vgetq_lane_f32(vpaddq_f32(sum, sum), 0);
|
||||
}
|
||||
|
||||
static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
|
||||
#ifdef __ARM_FEATURE_FMA
|
||||
return vfmaq_f32(a, b, c);
|
||||
#else
|
||||
return vaddq_f32(a, vmulq_f32(b, c));
|
||||
#endif
|
||||
}
|
||||
|
||||
static void fp16_gemv_trans_fp32_arith(const int m, const int n, const float16_t* a, const int lda, const float16_t *x, float16_t* y, int incy) {
|
||||
parallel_for(0, n / 4, 1, [&](int begin, int end) {
|
||||
for (auto i = begin * 4 ; i < end * 4; i += 4) {
|
||||
@ -356,13 +266,13 @@ static void fp16_gemv_trans_fp32_arith(const int m, const int n, const float16_t
|
||||
for (auto j = 0; j < m; j += 4) {
|
||||
float32x4_t xVec = vcvt_f32_f16(vld1_f16(x + j));
|
||||
float32x4_t a0Vec = vcvt_f32_f16(vld1_f16(row0 + j));
|
||||
sum0Vec = f32_fma(sum0Vec, a0Vec, xVec);
|
||||
sum0Vec = vaddq_f32(sum0Vec, vmulq_f32(a0Vec, xVec));
|
||||
float32x4_t a1Vec = vcvt_f32_f16(vld1_f16(row1 + j));
|
||||
sum1Vec = f32_fma(sum1Vec, a1Vec, xVec);
|
||||
sum1Vec = vaddq_f32(sum1Vec, vmulq_f32(a1Vec, xVec));
|
||||
float32x4_t a2Vec = vcvt_f32_f16(vld1_f16(row2 + j));
|
||||
sum2Vec = f32_fma(sum2Vec, a2Vec, xVec);
|
||||
sum2Vec = vaddq_f32(sum2Vec, vmulq_f32(a2Vec, xVec));
|
||||
float32x4_t a3Vec = vcvt_f32_f16(vld1_f16(row3 + j));
|
||||
sum3Vec = f32_fma(sum3Vec, a3Vec, xVec);
|
||||
sum3Vec = vaddq_f32(sum3Vec, vmulq_f32(a3Vec, xVec));
|
||||
}
|
||||
y[(i + 0) * incy] = reduce(sum0Vec);
|
||||
y[(i + 1) * incy] = reduce(sum1Vec);
|
||||
@ -385,16 +295,11 @@ void fp16_gemv_trans(
|
||||
const int incy) {
|
||||
if (incx == 1 && alpha == 1.0 && beta == 0.0 && m % 4 == 0 && n % 4 == 0) {
|
||||
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
|
||||
if (at::globalContext().allowFP16ReductionCPU()) {
|
||||
if (m % 32 == 0 && n % 32 == 0) {
|
||||
return fp16_gemv_trans_fp16_arith_by_dot_products(m, n, a, lda, x, y, incy);
|
||||
}
|
||||
if (m % 8 == 0) {
|
||||
return fp16_gemv_trans_fp16_arith(m, n, a, lda, x, y, incy);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return at::globalContext().allowFP16ReductionCPU() && m % 8 == 0 ? fp16_gemv_trans_fp16_arith(m, n, a, lda, x, y, incy)
|
||||
: fp16_gemv_trans_fp32_arith(m, n, a, lda, x, y, incy);
|
||||
#else
|
||||
return fp16_gemv_trans_fp32_arith(m, n, a, lda, x, y, incy);
|
||||
#endif
|
||||
}
|
||||
for (const auto i : c10::irange(n)) {
|
||||
float sum = 0;
|
||||
|
||||
@ -543,11 +543,6 @@ Tensor& slow_conv2d_forward_out_cpu(
|
||||
IntArrayRef padding,
|
||||
Tensor& output) {
|
||||
// See [Note: hacky wrapper removal for optional tensor]
|
||||
|
||||
TORCH_CHECK(kernel_size.size() == 2, "2D kernel_size expected");
|
||||
TORCH_CHECK(stride.size() == 2, "2D stride expected");
|
||||
TORCH_CHECK(padding.size() == 2, "2D padding expected");
|
||||
|
||||
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
||||
const Tensor& bias = *bias_maybe_owned;
|
||||
|
||||
|
||||
@ -1,59 +0,0 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/FusedAdagrad.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_fused_adagrad.h>
|
||||
#include <ATen/ops/_fused_adagrad_native.h>
|
||||
#endif
|
||||
namespace at {
|
||||
|
||||
namespace native {
|
||||
|
||||
void _fused_adagrad_kernel_cpu_(
|
||||
at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList state_sums,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double lr_decay,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf) {
|
||||
const float* grad_scale_ptr =
|
||||
grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
|
||||
const float* found_inf_ptr =
|
||||
found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
|
||||
if (found_inf_ptr && *found_inf_ptr == 1.0) {
|
||||
return;
|
||||
}
|
||||
size_t n_tensors = params.size();
|
||||
TORCH_CHECK(grads.size() == n_tensors);
|
||||
TORCH_CHECK(state_sums.size() == n_tensors);
|
||||
TORCH_CHECK(state_steps.size() == n_tensors);
|
||||
for (size_t i = 0; i < n_tensors; i++){
|
||||
fused_adagrad_stub(
|
||||
kCPU,
|
||||
params[i],
|
||||
grads[i],
|
||||
state_sums[i],
|
||||
state_steps[i],
|
||||
lr,
|
||||
lr_decay,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
DEFINE_DISPATCH(fused_adagrad_stub);
|
||||
|
||||
}
|
||||
}
|
||||
@ -1,23 +0,0 @@
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
|
||||
namespace at {
|
||||
|
||||
namespace native {
|
||||
|
||||
using fused_adagrad_fn = void (*)(
|
||||
const at::Tensor& param,
|
||||
const at::Tensor& grad,
|
||||
const at::Tensor& state_sum,
|
||||
const at::Tensor& state_step,
|
||||
const double lr,
|
||||
const double lr_decay,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const float* grad_scale_ptr);
|
||||
|
||||
DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub);
|
||||
|
||||
}
|
||||
}
|
||||
@ -8,7 +8,6 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_aminmax_native.h>
|
||||
#include <ATen/ops/aminmax.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/max.h>
|
||||
@ -66,11 +65,4 @@ Tensor& max_unary_out(const Tensor &self, Tensor& out) {
|
||||
return out;
|
||||
}
|
||||
|
||||
// DEPRECATED: Use at::aminmax instead
|
||||
std::tuple<Tensor, Tensor> _aminmax_all(const Tensor &self) {
|
||||
TORCH_WARN_ONCE("_aminmax is deprecated as of PyTorch 1.11 and will be removed in a future release. Use aminmax instead."
|
||||
" This warning will only appear once per process.");
|
||||
return at::aminmax(self);
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -20,7 +20,6 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_aminmax_native.h>
|
||||
#include <ATen/ops/_assert_async_native.h>
|
||||
#include <ATen/ops/_functional_assert_async_native.h>
|
||||
#include <ATen/ops/_print_native.h>
|
||||
@ -682,13 +681,6 @@ std::tuple<Tensor, Tensor> qmin(const Tensor& self, int64_t dim, bool keepdim) {
|
||||
at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
|
||||
}
|
||||
|
||||
// DEPRECATED: Use at::aminmax instead
|
||||
std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdim) {
|
||||
TORCH_WARN_ONCE("_aminmax is deprecated as of PyTorch 1.11 and will be removed in a future release. Use aminmax instead."
|
||||
" This warning will only appear once per process.");
|
||||
return at::aminmax(self, dim, keepdim);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(clamp_out)
|
||||
(
|
||||
const Tensor& /*self*/,
|
||||
|
||||
@ -191,8 +191,8 @@ ScalarType result_type(const Scalar& scalar1, const Scalar& scalar2) {
|
||||
return result_type(state);
|
||||
}
|
||||
|
||||
bool can_cast(const at::ScalarType from_, const at::ScalarType to) {
|
||||
return at::canCast(from_, to);
|
||||
bool can_cast(const at::ScalarType from, const at::ScalarType to) {
|
||||
return at::canCast(from, to);
|
||||
}
|
||||
|
||||
ScalarType promote_types(ScalarType type1, ScalarType type2) {
|
||||
|
||||
@ -1,225 +0,0 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/FusedAdagrad.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
namespace at::native {
|
||||
|
||||
namespace{
|
||||
|
||||
template <typename scalar_t, typename opmath_t>
|
||||
typename std::enable_if<
|
||||
std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
|
||||
void>::
|
||||
type inline adagrad_math(
|
||||
scalar_t* param_ptr,
|
||||
scalar_t* grad_ptr,
|
||||
scalar_t* state_sum_ptr,
|
||||
const double clr,
|
||||
const double eps,
|
||||
const double weight_decay,
|
||||
const bool maximize,
|
||||
const float* grad_scale_ptr,
|
||||
int64_t size
|
||||
){
|
||||
using lpVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<opmath_t>;
|
||||
lpVec grad_vec_to_store;
|
||||
fVec param_vec1, param_vec2;
|
||||
fVec grad_vec1, grad_vec2;
|
||||
fVec state_sum_vec1, state_sum_vec2;
|
||||
int64_t d = 0;
|
||||
for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
|
||||
lpVec param_lpvec = lpVec::loadu(param_ptr + d);
|
||||
std::tie(param_vec1, param_vec2) = vec::convert_to_float<scalar_t>(param_lpvec);
|
||||
lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
|
||||
std::tie(grad_vec1, grad_vec2) = vec::convert_to_float<scalar_t>(grad_lpvec);
|
||||
if (grad_scale_ptr) {
|
||||
grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
|
||||
grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
|
||||
grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
|
||||
grad_vec_to_store.store(grad_ptr + d);
|
||||
}
|
||||
if (maximize){
|
||||
grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
|
||||
grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
|
||||
}
|
||||
if (weight_decay != 0.0){
|
||||
grad_vec1 += param_vec1 * fVec(scalar_t(weight_decay));
|
||||
grad_vec2 += param_vec2 * fVec(scalar_t(weight_decay));
|
||||
}
|
||||
std::tie(state_sum_vec1, state_sum_vec2) = vec::convert_to_float<scalar_t>(lpVec::loadu(state_sum_ptr + d));
|
||||
state_sum_vec1 += grad_vec1 * grad_vec1;
|
||||
state_sum_vec2 += grad_vec2 * grad_vec2;
|
||||
vec::convert_from_float<scalar_t>(state_sum_vec1, state_sum_vec2).store(state_sum_ptr + d);
|
||||
|
||||
fVec std_vec1 = state_sum_vec1.sqrt() + fVec(scalar_t(eps));
|
||||
fVec std_vec2 = state_sum_vec2.sqrt() + fVec(scalar_t(eps));
|
||||
param_vec1 = param_vec1 - fVec(scalar_t(clr)) * grad_vec1 / std_vec1;
|
||||
param_vec2 = param_vec2 - fVec(scalar_t(clr)) * grad_vec2 / std_vec2;
|
||||
vec::convert_from_float<scalar_t>(param_vec1, param_vec2).store(param_ptr + d);
|
||||
}
|
||||
scalar_t grad_val_to_store;
|
||||
for (; d < size; d++) {
|
||||
opmath_t grad_val = grad_ptr[d];
|
||||
opmath_t param_val = param_ptr[d];
|
||||
if (grad_scale_ptr) {
|
||||
grad_val = grad_ptr[d] / opmath_t(*grad_scale_ptr);
|
||||
grad_val_to_store = grad_val;
|
||||
grad_ptr[d] = grad_val_to_store;
|
||||
}
|
||||
if (maximize) grad_val = -grad_val;
|
||||
if (weight_decay != 0.0){
|
||||
grad_val += param_val * opmath_t(weight_decay);
|
||||
}
|
||||
opmath_t state_sum_val = state_sum_ptr[d];
|
||||
state_sum_val += grad_val * grad_val;
|
||||
state_sum_ptr[d] = state_sum_val;
|
||||
opmath_t std_val = std::sqrt(state_sum_val) + opmath_t(eps);
|
||||
param_val -= opmath_t(clr) * grad_val / std_val;
|
||||
param_ptr[d] = param_val;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, typename opmath_t>
|
||||
typename std::enable_if<
|
||||
std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
|
||||
void>::
|
||||
type inline adagrad_math(
|
||||
scalar_t* param_ptr,
|
||||
scalar_t* grad_ptr,
|
||||
scalar_t* state_sum_ptr,
|
||||
const double clr,
|
||||
const double eps,
|
||||
const double weight_decay,
|
||||
const bool maximize,
|
||||
const float* grad_scale_ptr,
|
||||
int64_t size
|
||||
){
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
Vec grad_vec_to_store;
|
||||
int64_t d = 0;
|
||||
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
||||
Vec param_vec = Vec::loadu(param_ptr + d);
|
||||
Vec grad_vec = Vec::loadu(grad_ptr + d);
|
||||
if (grad_scale_ptr) {
|
||||
grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
|
||||
grad_vec_to_store = grad_vec;
|
||||
grad_vec_to_store.store(grad_ptr + d);
|
||||
}
|
||||
if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
|
||||
if (weight_decay != 0.0){
|
||||
grad_vec += param_vec * Vec(scalar_t(weight_decay));
|
||||
}
|
||||
|
||||
Vec sum_vec = Vec::loadu(state_sum_ptr + d) + grad_vec * grad_vec;
|
||||
sum_vec.store(state_sum_ptr + d);
|
||||
|
||||
Vec std_vec = sum_vec.sqrt() + Vec(scalar_t(eps));
|
||||
param_vec = param_vec - Vec(scalar_t(clr)) * grad_vec / std_vec;
|
||||
param_vec.store(param_ptr + d);
|
||||
}
|
||||
scalar_t grad_val_to_store;
|
||||
for (; d < size; d++) {
|
||||
scalar_t grad_val = grad_ptr[d];
|
||||
if (grad_scale_ptr) {
|
||||
grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
|
||||
grad_val_to_store = grad_val;
|
||||
grad_ptr[d] = grad_val_to_store;
|
||||
}
|
||||
if (maximize) grad_val = -grad_val;
|
||||
if (weight_decay != 0.0){
|
||||
grad_val += param_ptr[d] * scalar_t(weight_decay);
|
||||
}
|
||||
state_sum_ptr[d] += grad_val * grad_val;
|
||||
|
||||
scalar_t std_val = std::sqrt(state_sum_ptr[d]) + scalar_t(eps);
|
||||
param_ptr[d] -= scalar_t(clr) * grad_val / std_val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void adagrad_fused_step_impl(
|
||||
const at::Tensor& param,
|
||||
const at::Tensor& grad,
|
||||
const at::Tensor& state_sum,
|
||||
const at::Tensor& state_step,
|
||||
const double lr,
|
||||
const double lr_decay,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const float* grad_scale_ptr) {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
scalar_t* param_data = param.data_ptr<scalar_t>();
|
||||
scalar_t* grad_data = grad.data_ptr<scalar_t>();
|
||||
scalar_t* state_sum_data = state_sum.data_ptr<scalar_t>();
|
||||
double step = state_step.item<float>();
|
||||
double clr = lr / (1.0 + (step - 1.0) * lr_decay);
|
||||
|
||||
constexpr size_t cache_line_size = 64;
|
||||
constexpr int64_t cache_line_aligned_task_unit = cache_line_size / sizeof(scalar_t);
|
||||
size_t num_units = divup(param.numel(), cache_line_aligned_task_unit);
|
||||
|
||||
auto adagrad_fn = [&](int64_t begin, int64_t end) {
|
||||
// local pointers
|
||||
begin *= cache_line_aligned_task_unit;
|
||||
end = std::min(end * cache_line_aligned_task_unit, param.numel());
|
||||
scalar_t* param_ptr = param_data + begin;
|
||||
scalar_t* grad_ptr = grad_data + begin;
|
||||
scalar_t* state_sum_ptr = state_sum_data + begin;
|
||||
|
||||
const int64_t size = end - begin;
|
||||
adagrad_math<scalar_t, opmath_t>(
|
||||
param_ptr,
|
||||
grad_ptr,
|
||||
state_sum_ptr,
|
||||
clr,
|
||||
eps,
|
||||
weight_decay,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
size
|
||||
);
|
||||
};
|
||||
at::parallel_for(
|
||||
0, num_units, 0, adagrad_fn);
|
||||
}
|
||||
|
||||
void fused_adagrad_kernel(
|
||||
const at::Tensor& param,
|
||||
const at::Tensor& grad,
|
||||
const at::Tensor& state_sum,
|
||||
const at::Tensor& state_step,
|
||||
const double lr,
|
||||
const double lr_decay,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const float* grad_scale_ptr
|
||||
) {
|
||||
Tensor grad_contiguous = grad.contiguous();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_adagrad_kernel", [&] {
|
||||
adagrad_fused_step_impl<scalar_t>(
|
||||
param,
|
||||
grad,
|
||||
state_sum,
|
||||
state_step,
|
||||
lr,
|
||||
lr_decay,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(fused_adagrad_stub, &fused_adagrad_kernel);
|
||||
} // namespace at::native
|
||||
@ -341,46 +341,12 @@ inline void tinygemm_kernel(
|
||||
|
||||
#if !defined(C10_MOBILE) && defined(__aarch64__)
|
||||
#include <arm_neon.h>
|
||||
|
||||
// Loads 8 consecutive Half values with a de-interleaving load and widens them
// to float: val[0] holds the even-indexed elements, val[1] the odd-indexed.
inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
  const float16x4x2_t halves = vld2_f16(reinterpret_cast<const float16_t*>(ptr));
  float32x4x2_t widened;
  widened.val[0] = vcvt_f32_f16(halves.val[0]);
  widened.val[1] = vcvt_f32_f16(halves.val[1]);
  return widened;
}
|
||||
|
||||
// Narrows 4 floats to half precision and stores them contiguously at ptr.
inline void store_float32x4(Half* ptr, float32x4_t val) {
  const float16x4_t narrowed = vcvt_f16_f32(val);
  float16_t* const dst = reinterpret_cast<float16_t*>(ptr);
  vst1_f16(dst, narrowed);
}
|
||||
|
||||
// Loads 8 consecutive BFloat16 values with a de-interleaving load and widens
// them to float: val[0] holds the even-indexed elements, val[1] the
// odd-indexed. Each 16-bit pattern is zero-extended to 32 bits and shifted
// into the upper half of the word before being reinterpreted as a float.
inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
  const uint16x4x2_t raw = vld2_u16(reinterpret_cast<const uint16_t*>(ptr));
  float32x4x2_t widened;
  widened.val[0] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(raw.val[0]), 16));
  widened.val[1] = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(raw.val[1]), 16));
  return widened;
}
|
||||
|
||||
// Stores 4 floats as BFloat16 by keeping only the upper 16 bits of each float
// (plain truncation, no rounding).
inline void store_float32x4(BFloat16* ptr, float32x4_t val) {
  const uint32x4_t bits = vreinterpretq_u32_f32(val);
  // Shift right by 16 and narrow in one step: yields the high half of each lane.
  const uint16x4_t upper = vshrn_n_u32(bits, 16);
  vst1_u16(reinterpret_cast<uint16_t*>(ptr), upper);
}
|
||||
|
||||
// Loads 8 consecutive floats with a de-interleaving load: val[0] holds the
// even-indexed elements, val[1] the odd-indexed. No conversion is needed.
inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
  const float32x4x2_t pairs = vld2q_f32(ptr);
  return pairs;
}
|
||||
|
||||
// Stores 4 floats contiguously at ptr; the float overload needs no narrowing.
inline void store_float32x4(float* ptr, float32x4_t val) {
  float* const dst = ptr;
  vst1q_f32(dst, val);
}
|
||||
|
||||
template <int BLOCK_M, int BLOCK_N, typename T>
|
||||
inline void tinygemm_kernel_(
|
||||
const T* RESTRICT A,
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
inline void tinygemm_kernel(
|
||||
const Half* RESTRICT A,
|
||||
const uint8_t* RESTRICT B,
|
||||
const T* RESTRICT ScaleAndZeros,
|
||||
T* RESTRICT C,
|
||||
const Half* RESTRICT ScaleAndZeros,
|
||||
Half* RESTRICT C,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
@ -402,9 +368,9 @@ inline void tinygemm_kernel_(
|
||||
if (is_block_start(k, BLOCK_K)) {
|
||||
int kb = k / BLOCK_K;
|
||||
c10::ForcedUnroll<4>{}([&](auto i) {
|
||||
auto scales_and_zeros = load_as_float32x4x2(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8);
|
||||
scales[i] = scales_and_zeros.val[0];
|
||||
zeros[i] = scales_and_zeros.val[1];
|
||||
auto scales_and_zeros = vld2_f16(reinterpret_cast<const float16_t*>(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8));
|
||||
scales[i] = vcvt_f32_f16(scales_and_zeros.val[0]);
|
||||
zeros[i] = vcvt_f32_f16(scales_and_zeros.val[1]);
|
||||
});
|
||||
}
|
||||
c10::ForcedUnroll<4>{}([&](auto i) {
|
||||
@ -417,53 +383,11 @@ inline void tinygemm_kernel_(
|
||||
});
|
||||
}
|
||||
c10::ForcedUnroll<4>{}([&](auto i) {
|
||||
store_float32x4(C + m * ldc + n + i * 4, c_val[i]);
|
||||
vst1_f16(reinterpret_cast<float16_t*>(C + m * ldc + n + i * 4), vcvt_f16_f32(c_val[i]));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
inline void tinygemm_kernel(
|
||||
const Half* RESTRICT A,
|
||||
const uint8_t* RESTRICT B,
|
||||
const Half* RESTRICT ScaleAndZeros,
|
||||
Half* RESTRICT C,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
int K,
|
||||
int BLOCK_K) {
|
||||
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
|
||||
}
|
||||
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
inline void tinygemm_kernel(
|
||||
const BFloat16* RESTRICT A,
|
||||
const uint8_t* RESTRICT B,
|
||||
const BFloat16* RESTRICT ScaleAndZeros,
|
||||
BFloat16* RESTRICT C,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
int K,
|
||||
int BLOCK_K) {
|
||||
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
|
||||
}
|
||||
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
inline void tinygemm_kernel(
|
||||
const float* RESTRICT A,
|
||||
const uint8_t* RESTRICT B,
|
||||
const float* RESTRICT ScaleAndZeros,
|
||||
float* RESTRICT C,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
int K,
|
||||
int BLOCK_K) {
|
||||
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<int BLOCK_N>
|
||||
|
||||
@ -250,18 +250,10 @@ inline void tinygemm_kernel_(
|
||||
});
|
||||
}
|
||||
|
||||
#if __OPTIMIZE__
|
||||
float32x4_t scale_val = load_as_float32x4(scales);
|
||||
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
|
||||
C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i);
|
||||
});
|
||||
#else
|
||||
// Work around GCC's inability to infer the lane index at compile time
|
||||
// See https://github.com/pytorch/pytorch/issues/126283
|
||||
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
|
||||
C[m * ldc + i] = reduce(c_val[i]) * float(scales[i]);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -32,7 +32,6 @@
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/mul.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/ones.h>
|
||||
#include <ATen/ops/scalar_tensor_native.h>
|
||||
#include <ATen/ops/vdot_native.h>
|
||||
#endif
|
||||
@ -989,11 +988,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
else
|
||||
#endif
|
||||
{
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60200
|
||||
// hipBlasLT requires scaleD to be set to something in order to use AMAX
|
||||
auto dummy_options = TensorOptions().dtype(kFloat).device(kCUDA);
|
||||
auto dummy_scale = at::ones(1, dummy_options);
|
||||
#endif
|
||||
at::cuda::blas::scaled_gemm(
|
||||
args.transa,
|
||||
args.transb,
|
||||
@ -1011,19 +1005,15 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
bias ? bias->data_ptr(): nullptr,
|
||||
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
|
||||
args.result->data_ptr(),
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60200
|
||||
scale_result ? scale_result->data_ptr() : dummy_scale.data_ptr(),
|
||||
#else
|
||||
scale_result ? scale_result->data_ptr() : nullptr,
|
||||
#endif
|
||||
args.result_ld,
|
||||
out_dtype_,
|
||||
amax.data_ptr(),
|
||||
use_fast_accum);
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200
|
||||
// ROCm's hipBLASLt does not support amax before 6.2, so calculate separately
|
||||
#if defined(USE_ROCM)
|
||||
// rocm's hipblaslt does not yet support amax, so calculate separately
|
||||
amax = at::max(at::abs(out.to(kFloat)));
|
||||
#endif
|
||||
|
||||
|
||||
@ -86,8 +86,12 @@ struct FusedSgdMathFunctor {
|
||||
init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc)};
|
||||
const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
const auto use_faster_load_store =
|
||||
(n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned;
|
||||
#else
|
||||
const auto use_faster_load_store{false};
|
||||
#endif
|
||||
if (use_faster_load_store) {
|
||||
for (auto i_start = threadIdx.x;
|
||||
i_start * kILP < n && i_start * kILP < chunk_size;
|
||||
|
||||
@ -3762,18 +3762,6 @@
|
||||
# This function should be deprecated in favor of differential_analytic_matrix_function in FunctionsManual.cpp
|
||||
- func: matrix_exp_backward(Tensor self, Tensor grad) -> Tensor
|
||||
|
||||
# DEPRECATED: Use torch.aminmax instead
|
||||
- func: _aminmax(Tensor self) -> (Tensor, Tensor)
|
||||
dispatch:
|
||||
CPU, CUDA: _aminmax_all
|
||||
autogen: _aminmax.out
|
||||
|
||||
# DEPRECATED: Use torch.aminmax instead
|
||||
- func: _aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)
|
||||
dispatch:
|
||||
CPU, CUDA: _aminmax
|
||||
autogen: _aminmax.dim_out
|
||||
|
||||
- func: aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max)
|
||||
device_check: NoCheck # TensorIterator
|
||||
structured_delegate: aminmax.out
|
||||
@ -7714,7 +7702,7 @@
|
||||
|
||||
- func: result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType
|
||||
|
||||
- func: can_cast(ScalarType from_, ScalarType to) -> bool
|
||||
- func: can_cast(ScalarType from, ScalarType to) -> bool
|
||||
variants: function
|
||||
|
||||
- func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType
|
||||
@ -14720,13 +14708,13 @@
|
||||
CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _flash_attention_forward
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
|
||||
- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
|
||||
device_check: NoCheck
|
||||
variants: function
|
||||
dispatch:
|
||||
@ -15539,6 +15527,7 @@
|
||||
CPU: foobar
|
||||
autogen: _foobar.out
|
||||
|
||||
# Fused Optimizer CUDA kernels.
|
||||
- func: _fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
|
||||
# Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
|
||||
variants: function
|
||||
@ -15593,12 +15582,6 @@
|
||||
CUDA: _fused_sgd_kernel_cuda_
|
||||
autogen: _fused_sgd.tensor_lr, _fused_sgd.tensor_lr_out
|
||||
|
||||
- func: _fused_adagrad_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: _fused_adagrad_kernel_cpu_
|
||||
autogen: _fused_adagrad, _fused_adagrad.out
|
||||
|
||||
# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
|
||||
- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
|
||||
variants: function
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#else
|
||||
#include <ATen/ops/_aminmax.h>
|
||||
#include <ATen/ops/aminmax.h>
|
||||
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h>
|
||||
#include <ATen/ops/fake_quantize_per_channel_affine.h>
|
||||
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask.h>
|
||||
@ -148,7 +148,7 @@ void _calculate_moving_average(
|
||||
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (per_row_fq) {
|
||||
std::tie(x_min, x_max) = at::_aminmax(x, 1);
|
||||
std::tie(x_min, x_max) = at::aminmax(x, 1);
|
||||
float* x_min_data = x_min.data_ptr<float>();
|
||||
float* x_max_data = x_max.data_ptr<float>();
|
||||
int num_threads = std::min(size, (int64_t)512);
|
||||
@ -165,7 +165,7 @@ void _calculate_moving_average(
|
||||
size);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
std::tie(x_min, x_max) = at::_aminmax(x);
|
||||
std::tie(x_min, x_max) = at::aminmax(x);
|
||||
float* x_min_data = x_min.data_ptr<float>();
|
||||
float* x_max_data = x_max.data_ptr<float>();
|
||||
// Moving Average Min/Max observer for activations
|
||||
|
||||
@ -841,9 +841,7 @@ _flash_attention_forward(
|
||||
double dropout_p,
|
||||
bool is_causal,
|
||||
bool return_debug_mask,
|
||||
std::optional<double> scale,
|
||||
std::optional<int64_t> window_size_left,
|
||||
std::optional<int64_t> window_size_right) {
|
||||
std::optional<double> scale) {
|
||||
#if defined(USE_FLASH_ATTENTION)
|
||||
const auto softmax_scale =
|
||||
sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
@ -854,9 +852,6 @@ _flash_attention_forward(
|
||||
std::optional<Tensor> seqused_k = c10::nullopt;
|
||||
std::optional<Tensor> alibi_slopes = c10::nullopt;
|
||||
|
||||
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
|
||||
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
|
||||
|
||||
// We are going to have two paths:
|
||||
// 1. The standard MHA path for dense tensors
|
||||
// 2. The Varseqlen path
|
||||
@ -891,8 +886,8 @@ _flash_attention_forward(
|
||||
softmax_scale,
|
||||
false /*zero_tensors*/,
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
-1, /*window_size_left*/
|
||||
-1, /*window_size_right*/
|
||||
return_debug_mask,
|
||||
c10::nullopt /*gen_*/);
|
||||
} else {
|
||||
@ -914,8 +909,8 @@ _flash_attention_forward(
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
-1, /*window_size_left*/
|
||||
-1, /*window_size_right*/
|
||||
return_debug_mask, /*return_softmax (this is used for testing)*/
|
||||
c10::nullopt);
|
||||
}
|
||||
|
||||
@ -66,18 +66,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||
bool is_causal,
|
||||
const Tensor& philox_seed,
|
||||
const Tensor& philox_offset,
|
||||
std::optional<double> scale,
|
||||
std::optional<int64_t> window_size_left,
|
||||
std::optional<int64_t> window_size_right) {
|
||||
std::optional<double> scale) {
|
||||
#if defined(USE_FLASH_ATTENTION)
|
||||
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
// CUDA code assumes that dout is contiguous
|
||||
auto contiguous_grad_out = grad_out.contiguous();
|
||||
auto contiguous_out = out.contiguous();
|
||||
|
||||
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
|
||||
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
|
||||
|
||||
std::optional<at::Tensor> dq{c10::nullopt};
|
||||
std::optional<at::Tensor> dk{c10::nullopt};
|
||||
std::optional<at::Tensor> dv{c10::nullopt};
|
||||
@ -123,8 +118,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||
softmax_scale,
|
||||
false /*zero_tensors*/,
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
-1, /*window_size_left*/
|
||||
-1, /*window_size_right*/
|
||||
determinisitic,
|
||||
philox_seed,
|
||||
philox_offset);
|
||||
@ -145,8 +140,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
-1, /*window_size_left*/
|
||||
-1, /*window_size_right*/
|
||||
determinisitic,
|
||||
philox_seed,
|
||||
philox_offset);
|
||||
|
||||
@ -788,7 +788,7 @@ TEST_F(VulkanAPITest, avg_pool2d) {
|
||||
ASSERT_TRUE(check);
|
||||
}
|
||||
|
||||
TEST_F(VulkanAPITest, DISABLED_batch_norm_invalid_inputs) {
|
||||
TEST_F(VulkanAPITest, batch_norm_invalid_inputs) {
|
||||
c10::InferenceMode mode;
|
||||
|
||||
// Act: Vulkan batchnorm only supports evaluation mode
|
||||
|
||||
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,46
|
||||
hf_BigBird,pass,0
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass, 52
|
||||
hf_BigBird,pass,6
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -138,7 +138,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,fail_to_run,0
|
||||
hf_BigBird,fail_accuracy,0
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,46
|
||||
hf_BigBird,fail_to_run,0
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,52
|
||||
hf_BigBird,fail_to_run,3
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,fail_accuracy,46
|
||||
hf_BigBird,fail_to_run,0
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,52
|
||||
hf_BigBird,fail_to_run,3
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,46
|
||||
hf_BigBird,pass,0
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,52
|
||||
hf_BigBird,pass,6
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,fail_accuracy,46
|
||||
hf_BigBird,fail_accuracy,0
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,52
|
||||
hf_BigBird,pass,6
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -354,24 +354,6 @@ def patch_torch_manual_seed():
|
||||
torch.manual_seed = deterministic_torch_manual_seed
|
||||
|
||||
|
||||
def empty_gpu_cache(device):
    """
    Release cached accelerator memory so a following run is less likely to OOM.

    ``device`` is compared against the strings ``"cuda"`` and ``"xpu"``; any
    other value only logs a warning and leaves every cache untouched.
    """
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()
    else:
        # Unsupported device: report it and do nothing.
        log.warning(
            "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]",
            device,
        )
|
||||
|
||||
|
||||
def synchronize():
|
||||
pass
|
||||
|
||||
@ -1252,7 +1234,7 @@ def download_retry_decorator(download_fn):
|
||||
)
|
||||
time.sleep(wait)
|
||||
else:
|
||||
raise RuntimeError( # noqa: B904
|
||||
raise RuntimeError( # noqa: TRY200
|
||||
f"Failed to load model '{args}' with following error(s): {str(e)}."
|
||||
)
|
||||
|
||||
@ -2296,7 +2278,7 @@ class BenchmarkRunner:
|
||||
def batch_size_finder(self, device, model_name, initial_batch_size=1024):
|
||||
batch_size = initial_batch_size
|
||||
while batch_size >= 1:
|
||||
empty_gpu_cache(current_device)
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
device, name, model, example_inputs, _ = self.load_model(
|
||||
device,
|
||||
@ -2486,7 +2468,7 @@ class BenchmarkRunner:
|
||||
fp64_outputs = None
|
||||
finally:
|
||||
del model_fp64, inputs_fp64
|
||||
empty_gpu_cache(current_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
|
||||
self.args.training, current_device, name
|
||||
@ -2515,7 +2497,7 @@ class BenchmarkRunner:
|
||||
return record_status(accuracy_status, dynamo_start_stats=start_stats)
|
||||
finally:
|
||||
del model_copy
|
||||
empty_gpu_cache(current_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Rerun native pytorch
|
||||
reset_rng_state()
|
||||
@ -2536,7 +2518,7 @@ class BenchmarkRunner:
|
||||
return record_status(accuracy_status, dynamo_start_stats=start_stats)
|
||||
finally:
|
||||
del model_copy
|
||||
empty_gpu_cache(current_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Two eager runs should have exactly same result
|
||||
is_same = True
|
||||
@ -2737,7 +2719,7 @@ class BenchmarkRunner:
|
||||
try:
|
||||
if current_device == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
empty_gpu_cache(current_device)
|
||||
torch.cuda.empty_cache()
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(niters):
|
||||
fn(model, example_inputs)
|
||||
@ -2967,7 +2949,7 @@ class BenchmarkRunner:
|
||||
name, model, example_inputs, optimize_ctx, experiment, tag
|
||||
)
|
||||
print(status)
|
||||
empty_gpu_cache(current_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.maybe_preserve_compile_debug(name, status)
|
||||
|
||||
|
||||
@ -1,22 +1,12 @@
|
||||
#name,data_type,shape,wrapper,perf_speedup_target_c7i_metal_24xl
|
||||
#timm_vision_transformer,float32,static,default,1.039510755
|
||||
phlippe_densenet,float32,static,default,1.3988316
|
||||
basic_gnn_gcn,float32,dynamic,default,1.074576405
|
||||
llama_v2_7b_16h,float32,dynamic,default,1.211740245
|
||||
resnet50,float32,dynamic,default,1.65984261
|
||||
timm_efficientnet,float32,static,cpp,2.271561735
|
||||
mobilenet_v3_large,float32,static,cpp,2.63375628
|
||||
timm_resnest,float32,dynamic,cpp,1.67998548
|
||||
pyhpc_turbulent_kinetic_energy,float32,dynamic,cpp,1.59968463
|
||||
#hf_GPT2,float32,dynamic,cpp,
|
||||
hf_GPT2,float32,dynamic,cpp,1.379885175
|
||||
resnext50_32x4d,amp,static,default,1.461687045
|
||||
vgg16,amp,static,default,1.267194285
|
||||
hf_Longformer,amp,dynamic,default,0.997006035
|
||||
hf_Bert_large,amp,dynamic,default,0.99391146
|
||||
llama,amp,static,default,1.32950568
|
||||
timm_regnet,amp,static,cpp,1.157188305
|
||||
lennard_jones,amp,static,cpp,2.240104485
|
||||
hf_T5_generate,amp,dynamic,cpp,1.447656135
|
||||
timm_vovnet,amp,dynamic,cpp,1.07856471
|
||||
mobilenet_v2,amp,dynamic,cpp,2.27774577
|
||||
#name,data_type,shape,wrapper,perf_speedup_target_c5_12xlarge
|
||||
#timm_vision_transformer,float32,static,default,1.1585628
|
||||
phlippe_densenet,float32,static,default,1.99590617
|
||||
basic_gnn_gcn,float32,dynamic,default,1.24639561
|
||||
llama_v2_7b_16h,float32,dynamic,default,1.27455818
|
||||
resnet50,float32,dynamic,default,2.28794694
|
||||
timm_efficientnet,float32,static,cpp,2.72195686
|
||||
mobilenet_v3_large,float32,static,cpp,3.02274304
|
||||
timm_resnest,float32,dynamic,cpp,2.10118744
|
||||
shufflenet_v2_x1_0,float32,dynamic,cpp,1.8976929
|
||||
#hf_GPT2,float32,dynamic,cpp,1.6702305
|
||||
hf_GPT2,float32,dynamic,cpp,1.1183002
|
||||
|
||||
|
@ -15,10 +15,6 @@ from torch._dynamo.utils import clone_inputs
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Enable FX graph caching
|
||||
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
|
||||
torch._inductor.config.fx_graph_cache = True
|
||||
|
||||
|
||||
def pip_install(package):
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
||||
|
||||
@ -13,10 +13,6 @@ from common import BenchmarkRunner, download_retry_decorator, main
|
||||
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
|
||||
from torch._dynamo.utils import clone_inputs
|
||||
|
||||
# Enable FX graph caching
|
||||
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
|
||||
torch._inductor.config.fx_graph_cache = True
|
||||
|
||||
|
||||
def pip_install(package):
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
||||
|
||||
@ -3,7 +3,7 @@ import itertools
|
||||
from collections import defaultdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -29,32 +29,28 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExperimentConfig:
|
||||
shape: Tuple[int]
|
||||
batch_size: int
|
||||
num_heads: int
|
||||
q_seq_len: int
|
||||
k_seq_len: int
|
||||
head_dim: int
|
||||
score_mod: Callable
|
||||
dtype: torch.dtype
|
||||
calculate_bwd_time: bool
|
||||
|
||||
def __post_init__(self):
|
||||
assert len(self.shape) == 4, "Shape must be of length 4"
|
||||
|
||||
def asdict(self):
|
||||
# Convert the dataclass instance to a dictionary
|
||||
d = asdict(self)
|
||||
# Remove the 'calculate_bwd_time' key
|
||||
d.pop("calculate_bwd_time", None)
|
||||
return d
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Times:
|
||||
eager_time: float
|
||||
compiled_time: float
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExperimentResults:
|
||||
fwd_times: Times
|
||||
bwd_times: Optional[Times]
|
||||
eager_time: float
|
||||
compiled_time: float
|
||||
|
||||
def get_entries(self) -> List:
|
||||
return [
|
||||
f"{self.eager_time:2f}",
|
||||
f"{self.compiled_time:2f}",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -62,31 +58,29 @@ class Experiment:
|
||||
config: ExperimentConfig
|
||||
results: ExperimentResults
|
||||
|
||||
def get_entries(self) -> List:
|
||||
return self.config.get_entries() + self.results.get_entries()
|
||||
|
||||
def asdict(self):
|
||||
dict1 = self.config.asdict()
|
||||
dict1 = asdict(self.config)
|
||||
dict2 = asdict(self.results)
|
||||
return {**dict1, **dict2}
|
||||
|
||||
|
||||
def generate_inputs(
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
q_sequence_length: int,
|
||||
kv_sequence_length: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
requires_grad: bool,
|
||||
batch_size,
|
||||
num_heads,
|
||||
q_sequence_length,
|
||||
kv_sequence_length,
|
||||
head_dim,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
q_shape = (batch_size, q_sequence_length, num_heads * head_dim)
|
||||
kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim)
|
||||
|
||||
make_q = partial(
|
||||
torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad
|
||||
)
|
||||
make_kv = partial(
|
||||
torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad
|
||||
)
|
||||
make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
|
||||
make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
|
||||
query = (
|
||||
make_q()
|
||||
.view(batch_size, q_sequence_length, num_heads, head_dim)
|
||||
@ -107,16 +101,14 @@ def generate_inputs(
|
||||
|
||||
def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
|
||||
device = torch.device("cuda")
|
||||
batch_size, num_heads, q_seq_len, head_dim = config.shape
|
||||
query, key, value = generate_inputs(
|
||||
batch_size,
|
||||
num_heads,
|
||||
q_seq_len,
|
||||
q_seq_len,
|
||||
head_dim,
|
||||
config.batch_size,
|
||||
config.num_heads,
|
||||
config.q_seq_len,
|
||||
config.k_seq_len,
|
||||
config.head_dim,
|
||||
config.dtype,
|
||||
device,
|
||||
requires_grad=config.calculate_bwd_time,
|
||||
)
|
||||
|
||||
def eager_sdpa(query, key, value, _):
|
||||
@ -133,47 +125,23 @@ def run_single_experiment(config: ExperimentConfig, dynamic=False) -> Experiment
|
||||
compiled_sdpa, query, key, value, score_mod
|
||||
)
|
||||
|
||||
if config.calculate_bwd_time:
|
||||
out_eager = eager_sdpa(query, key, value, score_mod)
|
||||
dOut = torch.randn_like(out_eager)
|
||||
backward_eager_time = benchmark_torch_function_in_microseconds(
|
||||
out_eager.backward, dOut, retain_graph=True
|
||||
)
|
||||
|
||||
out_compile = compiled_sdpa(query, key, value, score_mod)
|
||||
dOut = torch.randn_like(out_eager)
|
||||
backward_compile_time = benchmark_torch_function_in_microseconds(
|
||||
out_compile.backward, dOut, retain_graph=True
|
||||
)
|
||||
|
||||
return ExperimentResults(
|
||||
fwd_times=Times(forward_eager_time, forward_compiled_time),
|
||||
bwd_times=Times(backward_eager_time, backward_compile_time),
|
||||
)
|
||||
else:
|
||||
return ExperimentResults(
|
||||
fwd_times=Times(forward_eager_time, forward_compiled_time),
|
||||
bwd_times=None,
|
||||
)
|
||||
return ExperimentResults(
|
||||
eager_time=forward_eager_time,
|
||||
compiled_time=forward_compiled_time,
|
||||
)
|
||||
|
||||
|
||||
def calculate_speedup(results: ExperimentResults, type: str) -> float:
|
||||
if type == "fwd":
|
||||
return results.fwd_times.eager_time / results.fwd_times.compiled_time
|
||||
elif type == "bwd":
|
||||
assert results.bwd_times is not None
|
||||
return results.bwd_times.eager_time / results.bwd_times.compiled_time
|
||||
else:
|
||||
raise ValueError(f"Invalid type {type}")
|
||||
def calculate_speedup(results: ExperimentResults) -> float:
|
||||
return results.eager_time / results.compiled_time
|
||||
|
||||
|
||||
def get_func_name(func):
|
||||
return func.__name__.split("<locals>.")[-1].split(" at ")[0]
|
||||
|
||||
|
||||
def get_average_speedups(results: List[Experiment], type: str):
|
||||
def get_average_speedups(results: List[Experiment]):
|
||||
# Calculate speedups
|
||||
speedups = [calculate_speedup(r.results, type) for r in results]
|
||||
speedups = [calculate_speedup(r.results) for r in results]
|
||||
|
||||
# Find indices of max and min speedups
|
||||
max_speedup_index = np.argmax(speedups)
|
||||
@ -209,39 +177,20 @@ def print_results(results: List[Experiment]):
|
||||
table_data = defaultdict(list)
|
||||
for experiment in results:
|
||||
for key, value in experiment.asdict().items():
|
||||
if key == "fwd_times":
|
||||
for name, time in value.items():
|
||||
table_data[f"fwd_{name}"].append(float(time))
|
||||
elif key == "bwd_times":
|
||||
if experiment.config.calculate_bwd_time:
|
||||
for name, time in value.items():
|
||||
table_data[f"bwd_{name}"].append(float(time))
|
||||
else:
|
||||
table_data[key].append(value)
|
||||
if key == "eager_time" or key == "compiled_time":
|
||||
value = float(value)
|
||||
table_data[key].append(value)
|
||||
|
||||
# Calculate speedups
|
||||
fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results]
|
||||
table_data["fwd_speedup"] = fwd_speedups
|
||||
if results[0].config.calculate_bwd_time:
|
||||
bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results]
|
||||
table_data["bwd_speedup"] = bwd_speedups
|
||||
speedups = [calculate_speedup(r.results) for r in results]
|
||||
table_data["speedup"] = speedups
|
||||
|
||||
table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
|
||||
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
|
||||
print("\n")
|
||||
print("FWD Speedups".center(125, "="))
|
||||
print("\n")
|
||||
average_data = get_average_speedups(results, type="fwd")
|
||||
average_data = get_average_speedups(results)
|
||||
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
|
||||
if results[0].config.calculate_bwd_time:
|
||||
print("\n")
|
||||
print("BWD Speedups".center(125, "="))
|
||||
print("\n")
|
||||
average_data = get_average_speedups(results, type="bwd")
|
||||
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
|
||||
|
||||
def generate_score_mods() -> List[Callable]:
|
||||
def noop(score, b, h, m, n):
|
||||
@ -259,8 +208,8 @@ def generate_score_mods() -> List[Callable]:
|
||||
return [noop, causal_mask, relative_bias, head_bias]
|
||||
|
||||
|
||||
def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
|
||||
batch_sizes = [2, 8, 16]
|
||||
def generate_experiment_configs() -> List[ExperimentConfig]:
|
||||
batch_sizes = [1, 8, 16]
|
||||
num_heads = [16]
|
||||
q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
|
||||
head_dims = [64, 128, 256]
|
||||
@ -279,49 +228,41 @@ def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
|
||||
) in itertools.product(
|
||||
batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
|
||||
):
|
||||
assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now."
|
||||
all_configs.append(
|
||||
ExperimentConfig(
|
||||
shape=(bsz, n_heads, q_seq_len, head_dim),
|
||||
batch_size=bsz,
|
||||
num_heads=n_heads,
|
||||
q_seq_len=q_seq_len,
|
||||
k_seq_len=kv_seq_len,
|
||||
head_dim=head_dim,
|
||||
score_mod=score_mod,
|
||||
dtype=dtype,
|
||||
calculate_bwd_time=calculate_bwd,
|
||||
)
|
||||
)
|
||||
|
||||
return all_configs
|
||||
|
||||
|
||||
def main(dynamic: bool, calculate_bwd: bool):
|
||||
def main(dynamic=False):
|
||||
seed = 123
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
results = []
|
||||
for config in tqdm(generate_experiment_configs(calculate_bwd)):
|
||||
for config in tqdm(generate_experiment_configs()):
|
||||
results.append(
|
||||
Experiment(config, run_single_experiment(config, dynamic=dynamic))
|
||||
)
|
||||
for config in tqdm(generate_experiment_configs(calculate_bwd)):
|
||||
results.append(Experiment(config, run_single_experiment(config)))
|
||||
|
||||
print_results(results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set up the argument parser
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run sweep over sizes and score mods for flex attention"
|
||||
)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--dynamic",
|
||||
action="store_true",
|
||||
help="Runs a dynamic shapes version of compiled flex attention.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--calculate-bwd", action="store_true", help="Calculate backward pass times"
|
||||
)
|
||||
|
||||
# Parse arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.dynamic, args.calculate_bwd)
|
||||
main(args.dynamic)
|
||||
|
||||
@ -487,7 +487,6 @@ libtorch_core_sources = sorted(
|
||||
# These files are the only ones that are supported on Windows.
|
||||
libtorch_distributed_base_sources = [
|
||||
"torch/csrc/distributed/c10d/Backend.cpp",
|
||||
"torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
|
||||
"torch/csrc/distributed/c10d/FileStore.cpp",
|
||||
"torch/csrc/distributed/c10d/Functional.cpp",
|
||||
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
|
||||
@ -681,7 +680,6 @@ libtorch_cuda_distributed_extra_sources = [
|
||||
"torch/csrc/distributed/c10d/UCCUtils.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
]
|
||||
@ -1175,7 +1173,6 @@ aten_native_source_codegen_list = [
|
||||
"aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
|
||||
"aten/src/ATen/native/cpu/FusedAdamKernel.cpp",
|
||||
"aten/src/ATen/native/cpu/FusedSGDKernel.cpp",
|
||||
"aten/src/ATen/native/cpu/FusedAdagradKernel.cpp",
|
||||
]
|
||||
|
||||
# This aten native source file list will not go through aten codegen process
|
||||
@ -1412,7 +1409,6 @@ aten_native_source_non_codegen_list = [
|
||||
"aten/src/ATen/native/xnnpack/Shim.cpp",
|
||||
"aten/src/ATen/native/FusedAdam.cpp",
|
||||
"aten/src/ATen/native/FusedSGD.cpp",
|
||||
"aten/src/ATen/native/FusedAdagrad.cpp",
|
||||
# Files not in native, but depends on native symbols
|
||||
# "aten/src/ATen/TensorIndexing.cpp",
|
||||
"aten/src/ATen/TensorIterator.cpp",
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
#include <c10/util/TypeSafeSignMath.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
#include <cmath>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
|
||||
@ -457,9 +457,6 @@ if(BUILD_LITE_INTERPRETER)
|
||||
append_filelist("libtorch_lite_cmake_sources" LIBTORCH_CMAKE_SRCS)
|
||||
list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS})
|
||||
list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_PROFILER_SRCS})
|
||||
if(USE_LITE_AOTI)
|
||||
append_filelist("inductor_core_resources" LIBTORCH_CMAKE_SRCS)
|
||||
endif()
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE TRUE)
|
||||
else()
|
||||
append_filelist("libtorch_cmake_sources" LIBTORCH_CMAKE_SRCS)
|
||||
|
||||
36
caffe2/python/CMakeLists.txt
Normal file
36
caffe2/python/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
||||
# ---[ CPU files.
|
||||
set(Caffe2_CPU_PYTHON_SRCS
|
||||
"/pybind_state.cc"
|
||||
"/pybind_workspace.cc"
|
||||
"/pybind_state_dlpack.cc"
|
||||
"/pybind_state_nomni.cc"
|
||||
"/pybind_state_registry.cc"
|
||||
"/pybind_state_int8.cc"
|
||||
)
|
||||
|
||||
if(USE_MKLDNN)
|
||||
set(Caffe2_CPU_PYTHON_SRCS
|
||||
${Caffe2_CPU_PYTHON_SRCS}
|
||||
"/pybind_state_ideep.cc"
|
||||
)
|
||||
endif()
|
||||
|
||||
# ---[ GPU files
|
||||
set(Caffe2_GPU_PYTHON_SRCS
|
||||
${Caffe2_CPU_PYTHON_SRCS}
|
||||
"/pybind_state_gpu.cc"
|
||||
)
|
||||
|
||||
# ---[ HIP files
|
||||
set(Caffe2_HIP_PYTHON_SRCS
|
||||
${Caffe2_CPU_PYTHON_SRCS}
|
||||
"/pybind_state_hip.cc"
|
||||
)
|
||||
|
||||
prepend(Caffe2_CPU_PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR} ${Caffe2_CPU_PYTHON_SRCS})
|
||||
prepend(Caffe2_GPU_PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR} ${Caffe2_GPU_PYTHON_SRCS})
|
||||
prepend(Caffe2_HIP_PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR} ${Caffe2_HIP_PYTHON_SRCS})
|
||||
|
||||
set(Caffe2_CPU_PYTHON_SRCS ${Caffe2_CPU_PYTHON_SRCS} PARENT_SCOPE)
|
||||
set(Caffe2_GPU_PYTHON_SRCS ${Caffe2_GPU_PYTHON_SRCS} PARENT_SCOPE)
|
||||
set(Caffe2_HIP_PYTHON_SRCS ${Caffe2_HIP_PYTHON_SRCS} PARENT_SCOPE)
|
||||
86
caffe2/python/__init__.py
Normal file
86
caffe2/python/__init__.py
Normal file
@ -0,0 +1,86 @@
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
|
||||
try:
|
||||
from caffe2.proto import caffe2_pb2
|
||||
except ImportError:
|
||||
warnings.warn('Caffe2 support is no longer present in PyTorch.')
|
||||
raise
|
||||
|
||||
# TODO: refactor & remove the following alias
|
||||
caffe2_pb2.CPU = caffe2_pb2.PROTO_CPU
|
||||
caffe2_pb2.CUDA = caffe2_pb2.PROTO_CUDA
|
||||
caffe2_pb2.MKLDNN = caffe2_pb2.PROTO_MKLDNN
|
||||
caffe2_pb2.OPENGL = caffe2_pb2.PROTO_OPENGL
|
||||
caffe2_pb2.OPENCL = caffe2_pb2.PROTO_OPENCL
|
||||
caffe2_pb2.IDEEP = caffe2_pb2.PROTO_IDEEP
|
||||
caffe2_pb2.HIP = caffe2_pb2.PROTO_HIP
|
||||
caffe2_pb2.COMPILE_TIME_MAX_DEVICE_TYPES = caffe2_pb2.PROTO_COMPILE_TIME_MAX_DEVICE_TYPES
|
||||
|
||||
if sys.platform == "win32":
|
||||
is_conda = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
|
||||
py_dll_path = os.path.join(os.path.dirname(sys.executable), 'Library', 'bin')
|
||||
th_root = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'torch')
|
||||
th_dll_path = os.path.join(th_root, 'lib')
|
||||
|
||||
if not os.path.exists(os.path.join(th_dll_path, 'nvToolsExt64_1.dll')) and \
|
||||
not os.path.exists(os.path.join(py_dll_path, 'nvToolsExt64_1.dll')):
|
||||
nvtoolsext_dll_path = os.path.join(
|
||||
os.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt'), 'bin', 'x64')
|
||||
else:
|
||||
nvtoolsext_dll_path = ''
|
||||
|
||||
import importlib.util
|
||||
import glob
|
||||
spec = importlib.util.spec_from_file_location('torch_version', os.path.join(th_root, 'version.py'))
|
||||
torch_version = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(torch_version)
|
||||
if torch_version.cuda and len(glob.glob(os.path.join(th_dll_path, 'cudart64*.dll'))) == 0 and \
|
||||
len(glob.glob(os.path.join(py_dll_path, 'cudart64*.dll'))) == 0:
|
||||
cuda_version = torch_version.cuda
|
||||
cuda_version_1 = cuda_version.replace('.', '_')
|
||||
cuda_path_var = 'CUDA_PATH_V' + cuda_version_1
|
||||
default_path = 'C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v' + cuda_version
|
||||
cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), 'bin')
|
||||
else:
|
||||
cuda_path = ''
|
||||
|
||||
import ctypes
|
||||
kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
|
||||
dll_paths = list(filter(os.path.exists, [th_dll_path, py_dll_path, nvtoolsext_dll_path, cuda_path]))
|
||||
with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
|
||||
prev_error_mode = kernel32.SetErrorMode(0x0001)
|
||||
|
||||
kernel32.LoadLibraryW.restype = ctypes.c_void_p
|
||||
if with_load_library_flags:
|
||||
kernel32.LoadLibraryExW.restype = ctypes.c_void_p
|
||||
|
||||
for dll_path in dll_paths:
|
||||
os.add_dll_directory(dll_path)
|
||||
|
||||
dlls = glob.glob(os.path.join(th_dll_path, '*.dll'))
|
||||
path_patched = False
|
||||
for dll in dlls:
|
||||
is_loaded = False
|
||||
if with_load_library_flags:
|
||||
res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
|
||||
last_error = ctypes.get_last_error()
|
||||
if res is None and last_error != 126:
|
||||
err = ctypes.WinError(last_error)
|
||||
err.strerror += ' Error loading "{}" or one of its dependencies.'.format(dll)
|
||||
raise err
|
||||
elif res is not None:
|
||||
is_loaded = True
|
||||
if not is_loaded:
|
||||
if not path_patched:
|
||||
os.environ['PATH'] = ';'.join(dll_paths + [os.environ['PATH']])
|
||||
path_patched = True
|
||||
res = kernel32.LoadLibraryW(dll)
|
||||
if res is None:
|
||||
err = ctypes.WinError(ctypes.get_last_error())
|
||||
err.strerror += ' Error loading "{}" or one of its dependencies.'.format(dll)
|
||||
raise err
|
||||
|
||||
kernel32.SetErrorMode(prev_error_mode)
|
||||
57
caffe2/python/_import_c_extension.py
Normal file
57
caffe2/python/_import_c_extension.py
Normal file
@ -0,0 +1,57 @@
|
||||
## @package _import_c_extension
|
||||
# Module caffe2.python._import_c_extension
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
from caffe2.python import extension_loader
|
||||
|
||||
# We will first try to load the gpu-enabled caffe2. If it fails, we will then
|
||||
# attempt to load the cpu version. The cpu backend is the minimum required, so
|
||||
# if that still fails, we will exit loud.
|
||||
with extension_loader.DlopenGuard():
|
||||
has_hip_support = False
|
||||
has_cuda_support = False
|
||||
has_gpu_support = False
|
||||
|
||||
try:
|
||||
from caffe2.python.caffe2_pybind11_state_gpu import * # noqa
|
||||
if num_cuda_devices(): # noqa
|
||||
has_gpu_support = has_cuda_support = True
|
||||
except ImportError as gpu_e:
|
||||
logging.info('Failed to import cuda module: {}'.format(gpu_e))
|
||||
try:
|
||||
from caffe2.python.caffe2_pybind11_state_hip import * # noqa
|
||||
# we stop checking whether we have AMD GPU devices on the host,
|
||||
# because we may be constructing a net on a machine without GPU,
|
||||
# and run the net on another one with GPU
|
||||
has_gpu_support = has_hip_support = True
|
||||
logging.info('This caffe2 python run has AMD GPU support!')
|
||||
except ImportError as hip_e:
|
||||
logging.info('Failed to import AMD hip module: {}'.format(hip_e))
|
||||
|
||||
logging.warning(
|
||||
'This caffe2 python run failed to load cuda module:{},'
|
||||
'and AMD hip module:{}.'
|
||||
'Will run in CPU only mode.'.format(gpu_e, hip_e))
|
||||
try:
|
||||
from caffe2.python.caffe2_pybind11_state import * # noqa
|
||||
except ImportError as cpu_e:
|
||||
logging.critical(
|
||||
'Cannot load caffe2.python. Error: {0}'.format(str(cpu_e)))
|
||||
sys.exit(1)
|
||||
|
||||
# libcaffe2_python contains a global Workspace that we need to properly delete
|
||||
# when exiting. Otherwise, cudart will cause segfaults sometimes.
|
||||
atexit.register(on_module_exit) # noqa
|
||||
|
||||
|
||||
# Add functionalities for the TensorCPU interface.
|
||||
def _TensorCPU_shape(self):
|
||||
return tuple(self._shape)
|
||||
|
||||
|
||||
def _TensorCPU_reshape(self, shape):
|
||||
return self._reshape(list(shape))
|
||||
|
||||
TensorCPU.shape = property(_TensorCPU_shape) # noqa
|
||||
TensorCPU.reshape = _TensorCPU_reshape # noqa
|
||||
227
caffe2/python/_import_c_extension.pyi
Normal file
227
caffe2/python/_import_c_extension.pyi
Normal file
@ -0,0 +1,227 @@
|
||||
import collections
|
||||
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union, overload
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import google.protobuf.message
|
||||
import torch
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
||||
from . import core
|
||||
|
||||
# pybind11 will automatically accept either Python str or bytes for C++ APIs
|
||||
# that accept std::string.
|
||||
_PybindStr: TypeAlias = Union[str, bytes]
|
||||
_PerOpEnginePrefType: TypeAlias = Dict[int, Dict[str, List[str]]]
|
||||
_EnginePrefType: TypeAlias = Dict[int, List[str]]
|
||||
|
||||
Int8Tensor = collections.namedtuple(
|
||||
'Int8Tensor', ['data', 'scale', 'zero_point']
|
||||
)
|
||||
|
||||
|
||||
class _HasProto(Protocol):
|
||||
def Proto(self) -> Any: ...
|
||||
|
||||
|
||||
class TensorCPU:
|
||||
def init(self, dims: List[int], caffe_type: int) -> None: ...
|
||||
def to_torch(self) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class Blob:
|
||||
def feed(
|
||||
self,
|
||||
arg: Any,
|
||||
device_option: Union[
|
||||
None, str, bytes, google.protobuf.message.Message, _HasProto,
|
||||
] = None,
|
||||
) -> bool: ...
|
||||
def is_tensor(self) -> bool: ...
|
||||
def as_tensor(self) -> TensorCPU: ...
|
||||
def tensor(self) -> TensorCPU: ...
|
||||
def to_torch(self) -> torch.Tensor: ...
|
||||
def fetch(self) -> Any: ...
|
||||
|
||||
|
||||
class Net:
|
||||
def run(self) -> None: ...
|
||||
def cancel(self) -> None: ...
|
||||
|
||||
|
||||
class Workspace:
|
||||
@overload
|
||||
def __init__(self) -> None: ...
|
||||
@overload
|
||||
def __init__(self, workspace: Workspace) -> None: ...
|
||||
@property
|
||||
def blobs(self) -> Dict[str, Blob]: ...
|
||||
def create_blob(self, name: _PybindStr) -> Blob: ...
|
||||
def fetch_blob(self, name: _PybindStr) -> Any: ...
|
||||
def fetch_int8_blob(
|
||||
self, name: Union[str, bytes, core.BlobReference]
|
||||
) -> Int8Tensor: ...
|
||||
def _create_net(self, _def: bytes, overwrite: bool) -> Net: ...
|
||||
def create_net(
|
||||
self,
|
||||
net: Union[str, bytes, core.Net, caffe2_pb2.NetDef],
|
||||
overwrite: bool = False,
|
||||
) -> Net: ...
|
||||
def _run_net(self, _def: bytes) -> None: ...
|
||||
def _run_operator(self, _def: bytes) -> None: ...
|
||||
def _run_plan(self, _def: bytes) -> None: ...
|
||||
def run(
|
||||
self,
|
||||
obj: Union[
|
||||
caffe2_pb2.PlanDef,
|
||||
caffe2_pb2.NetDef,
|
||||
caffe2_pb2.OperatorDef,
|
||||
_HasProto,
|
||||
],
|
||||
) -> None: ...
|
||||
def feed_blob(
|
||||
self,
|
||||
name: Union[str, bytes, core.BlobReference],
|
||||
arr: Union[caffe2_pb2.TensorProto, np.ndarray],
|
||||
device_option: Optional[caffe2_pb2.DeviceOption] = None,
|
||||
) -> bool: ...
|
||||
def remove_blob(self, blob: Any) -> None: ...
|
||||
|
||||
current: Workspace
|
||||
|
||||
|
||||
class Argument:
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
@property
|
||||
def description(self) -> str: ...
|
||||
@property
|
||||
def required(self) -> bool: ...
|
||||
|
||||
|
||||
class OpSchema:
|
||||
@staticmethod
|
||||
def get(key: str) -> OpSchema: ...
|
||||
@property
|
||||
def args(self) -> List[Argument]: ...
|
||||
@property
|
||||
def input_desc(self) -> List[Tuple[str, str]]: ...
|
||||
@property
|
||||
def output_desc(self) -> List[Tuple[str, str]]: ...
|
||||
@property
|
||||
def max_input(self) -> int: ...
|
||||
@property
|
||||
def max_output(self) -> int: ...
|
||||
@property
|
||||
def min_input(self) -> int: ...
|
||||
@property
|
||||
def min_output(self) -> int: ...
|
||||
def inplace_enforced(self, x: int, y: int) -> bool: ...
|
||||
|
||||
|
||||
class DummyName:
|
||||
...
|
||||
|
||||
|
||||
class Graph:
|
||||
...
|
||||
|
||||
|
||||
class Node:
|
||||
...
|
||||
|
||||
|
||||
class Edge:
|
||||
...
|
||||
|
||||
|
||||
class NeuralNetOperator:
|
||||
...
|
||||
|
||||
|
||||
class NeuralNetData:
|
||||
...
|
||||
|
||||
|
||||
class NNSubgraph:
|
||||
...
|
||||
|
||||
|
||||
class NNMatchGraph:
|
||||
...
|
||||
|
||||
|
||||
class Annotation:
|
||||
...
|
||||
|
||||
|
||||
is_asan: bool
|
||||
has_mkldnn: bool
|
||||
use_mkldnn: bool
|
||||
has_fbgemm: bool
|
||||
use_rocm: bool
|
||||
use_trt: bool
|
||||
define_caffe2_no_operator_schema: bool
|
||||
|
||||
def registered_dbs() -> List[str]: ...
|
||||
def get_build_options() -> Dict[str, str]: ...
|
||||
def set_per_op_engine_pref(pref: _PerOpEnginePrefType) -> None: ...
|
||||
def set_global_engine_pref(pref: _EnginePrefType) -> None: ...
|
||||
def set_engine_pref(
|
||||
per_op_pref: _PerOpEnginePrefType, global_pref: _EnginePrefType
|
||||
) -> None: ...
|
||||
def set_op_engine_pref(
|
||||
op_type: _PybindStr, op_pref: _EnginePrefType
|
||||
) -> None: ...
|
||||
def op_registry_key(op_type: _PybindStr, engine: _PybindStr) -> str: ...
|
||||
def global_init(args: List[str]) -> None: ...
|
||||
def registered_operators() -> List[str]: ...
|
||||
def on_module_exit() -> None: ...
|
||||
@overload
|
||||
def switch_workspace(ws: Workspace): ...
|
||||
@overload
|
||||
def switch_workspace(name: _PybindStr, create_if_missing: Optional[bool] = None): ...
|
||||
def create_child_workspace(
|
||||
parent_ws_name: _PybindStr, child_ws_name: _PybindStr
|
||||
) -> None: ...
|
||||
def root_folder() -> str: ...
|
||||
def current_workspace() -> str: ...
|
||||
def workspaces() -> List[str]: ...
|
||||
def benchmark_net(
|
||||
name: _PybindStr, warmup_runs: int, main_runs: int, run_individual: bool
|
||||
) -> List[float]: ...
|
||||
def benchmark_net_once(name: _PybindStr) -> float: ...
|
||||
|
||||
def blobs() -> Dict[str, Blob]: ...
|
||||
def has_blob(name: _PybindStr) -> bool: ...
|
||||
def create_blob(name: _PybindStr) -> bool: ...
|
||||
def reset_blob(name: _PybindStr) -> None: ...
|
||||
@overload
|
||||
def deserialize_blob(content: _PybindStr) -> Blob: ...
|
||||
@overload
|
||||
def deserialize_blob(name: _PybindStr, serialized: bytes) -> None: ...
|
||||
def serialize_blob(name: _PybindStr) -> bytes: ...
|
||||
|
||||
def get_stats() -> Dict[str, int]: ...
|
||||
def is_numa_enabled() -> bool: ...
|
||||
def get_num_numa_nodes() -> int: ...
|
||||
def get_blob_numa_node(blob_name: _PybindStr) -> int: ...
|
||||
def get_blob_size_bytes(blob_name: _PybindStr) -> int: ...
|
||||
def create_offline_tensor(
|
||||
name: _PybindStr, dims: List[int], datatype: int
|
||||
) -> bool: ...
|
||||
def fakeFp16FuseOps(net_str: bytes) -> bytes: ...
|
||||
|
||||
def num_cuda_devices() -> int: ...
|
||||
def get_cuda_version() -> int: ...
|
||||
def get_cudnn_version() -> int: ...
|
||||
def get_gpu_memory_info(device_id: int) -> Tuple[int, int]: ...
|
||||
def get_device_properties(deviceid: int) -> Dict[str, Any]: ...
|
||||
|
||||
def num_hip_devices() -> int: ...
|
||||
def get_hip_version() -> int: ...
|
||||
def get_miopen_version() -> int: ...
|
||||
|
||||
has_hip_support: bool
|
||||
has_cuda_support: bool
|
||||
has_gpu_support: bool
|
||||
87
caffe2/python/allcompare_test.py
Normal file
87
caffe2/python/allcompare_test.py
Normal file
@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from hypothesis import given, settings
|
||||
import hypothesis.strategies as st
|
||||
from multiprocessing import Process
|
||||
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
|
||||
op_engine = 'GLOO'
|
||||
|
||||
|
||||
class TemporaryDirectory:
|
||||
def __enter__(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
return self.tmpdir
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
shutil.rmtree(self.tmpdir)
|
||||
|
||||
|
||||
def allcompare_process(filestore_dir, process_id, data, num_procs):
|
||||
from caffe2.python import core, data_parallel_model, workspace, dyndep
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
from caffe2.proto import caffe2_pb2
|
||||
dyndep.InitOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
|
||||
|
||||
workspace.RunOperatorOnce(
|
||||
core.CreateOperator(
|
||||
"FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
|
||||
)
|
||||
)
|
||||
rendezvous = dict(
|
||||
kv_handler="store_handler",
|
||||
shard_id=process_id,
|
||||
num_shards=num_procs,
|
||||
engine=op_engine,
|
||||
exit_nets=None
|
||||
)
|
||||
|
||||
model = ModelHelper()
|
||||
model._rendezvous = rendezvous
|
||||
|
||||
workspace.FeedBlob("test_data", data)
|
||||
|
||||
data_parallel_model._RunComparison(
|
||||
model, "test_data", core.DeviceOption(caffe2_pb2.CPU, 0)
|
||||
)
|
||||
|
||||
|
||||
class TestAllCompare(hu.HypothesisTestCase):
|
||||
@given(
|
||||
d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
|
||||
)
|
||||
@settings(deadline=10000)
|
||||
def test_allcompare(self, d, n, num_procs):
|
||||
dims = []
|
||||
for _ in range(d):
|
||||
dims.append(np.random.randint(1, high=n))
|
||||
test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)
|
||||
|
||||
with TemporaryDirectory() as tempdir:
|
||||
processes = []
|
||||
for idx in range(num_procs):
|
||||
process = Process(
|
||||
target=allcompare_process,
|
||||
args=(tempdir, idx, test_data, num_procs)
|
||||
)
|
||||
processes.append(process)
|
||||
process.start()
|
||||
|
||||
while len(processes) > 0:
|
||||
process = processes.pop()
|
||||
process.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
424
caffe2/python/attention.py
Normal file
424
caffe2/python/attention.py
Normal file
@ -0,0 +1,424 @@
|
||||
## @package attention
|
||||
# Module caffe2.python.attention
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import brew
|
||||
|
||||
|
||||
class AttentionType:
|
||||
Regular, Recurrent, Dot, SoftCoverage = tuple(range(4))
|
||||
|
||||
|
||||
def s(scope, name):
|
||||
# We have to manually scope due to our internal/external blob
|
||||
# relationships.
|
||||
return "{}/{}".format(str(scope), str(name))
|
||||
|
||||
|
||||
# c_i = \sum_j w_{ij}\textbf{s}_j
|
||||
def _calc_weighted_context(
|
||||
model,
|
||||
encoder_outputs_transposed,
|
||||
encoder_output_dim,
|
||||
attention_weights_3d,
|
||||
scope,
|
||||
):
|
||||
# [batch_size, encoder_output_dim, 1]
|
||||
attention_weighted_encoder_context = brew.batch_mat_mul(
|
||||
model,
|
||||
[encoder_outputs_transposed, attention_weights_3d],
|
||||
s(scope, 'attention_weighted_encoder_context'),
|
||||
)
|
||||
# [batch_size, encoder_output_dim]
|
||||
attention_weighted_encoder_context, _ = model.net.Reshape(
|
||||
attention_weighted_encoder_context,
|
||||
[
|
||||
attention_weighted_encoder_context,
|
||||
s(scope, 'attention_weighted_encoder_context_old_shape'),
|
||||
],
|
||||
shape=[1, -1, encoder_output_dim],
|
||||
)
|
||||
return attention_weighted_encoder_context
|
||||
|
||||
|
||||
# Calculate a softmax over the passed in attention energy logits
|
||||
def _calc_attention_weights(
|
||||
model,
|
||||
attention_logits_transposed,
|
||||
scope,
|
||||
encoder_lengths=None,
|
||||
):
|
||||
if encoder_lengths is not None:
|
||||
attention_logits_transposed = model.net.SequenceMask(
|
||||
[attention_logits_transposed, encoder_lengths],
|
||||
['masked_attention_logits'],
|
||||
mode='sequence',
|
||||
)
|
||||
|
||||
# [batch_size, encoder_length, 1]
|
||||
attention_weights_3d = brew.softmax(
|
||||
model,
|
||||
attention_logits_transposed,
|
||||
s(scope, 'attention_weights_3d'),
|
||||
engine='CUDNN',
|
||||
axis=1,
|
||||
)
|
||||
return attention_weights_3d
|
||||
|
||||
|
||||
# e_{ij} = \textbf{v}^T tanh \alpha(\textbf{h}_{i-1}, \textbf{s}_j)
|
||||
def _calc_attention_logits_from_sum_match(
|
||||
model,
|
||||
decoder_hidden_encoder_outputs_sum,
|
||||
encoder_output_dim,
|
||||
scope,
|
||||
):
|
||||
# [encoder_length, batch_size, encoder_output_dim]
|
||||
decoder_hidden_encoder_outputs_sum = model.net.Tanh(
|
||||
decoder_hidden_encoder_outputs_sum,
|
||||
decoder_hidden_encoder_outputs_sum,
|
||||
)
|
||||
|
||||
# [encoder_length, batch_size, 1]
|
||||
attention_logits = brew.fc(
|
||||
model,
|
||||
decoder_hidden_encoder_outputs_sum,
|
||||
s(scope, 'attention_logits'),
|
||||
dim_in=encoder_output_dim,
|
||||
dim_out=1,
|
||||
axis=2,
|
||||
freeze_bias=True,
|
||||
)
|
||||
|
||||
# [batch_size, encoder_length, 1]
|
||||
attention_logits_transposed = brew.transpose(
|
||||
model,
|
||||
attention_logits,
|
||||
s(scope, 'attention_logits_transposed'),
|
||||
axes=[1, 0, 2],
|
||||
)
|
||||
return attention_logits_transposed
|
||||
|
||||
|
||||
# \textbf{W}^\alpha used in the context of \alpha_{sum}(a,b)
|
||||
def _apply_fc_weight_for_sum_match(
|
||||
model,
|
||||
input,
|
||||
dim_in,
|
||||
dim_out,
|
||||
scope,
|
||||
name,
|
||||
):
|
||||
output = brew.fc(
|
||||
model,
|
||||
input,
|
||||
s(scope, name),
|
||||
dim_in=dim_in,
|
||||
dim_out=dim_out,
|
||||
axis=2,
|
||||
)
|
||||
output = model.net.Squeeze(
|
||||
output,
|
||||
output,
|
||||
dims=[0],
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
# Implement RecAtt due to section 4.1 in http://arxiv.org/abs/1601.03317
|
||||
def apply_recurrent_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    attention_weighted_encoder_context_t_prev,
    scope,
    encoder_lengths=None,
):
    """Build the recurrent-attention sub-graph for one decoder step.

    The previous attention context and the current decoder hidden state are
    each projected with their own FC, summed, broadcast-added to the weighted
    encoder outputs, and then scored, normalised and used to weight the
    encoder outputs.

    Returns:
        A 3-tuple: (attention-weighted encoder context, 3-D attention
        weights, [decoder_hidden_encoder_outputs_sum]).
    """
    prev_context_proj = _apply_fc_weight_for_sum_match(
        model,
        attention_weighted_encoder_context_t_prev,
        encoder_output_dim,
        encoder_output_dim,
        scope,
        'weighted_prev_attention_context',
    )

    decoder_state_proj = _apply_fc_weight_for_sum_match(
        model,
        decoder_hidden_state_t,
        decoder_hidden_state_dim,
        encoder_output_dim,
        scope,
        'weighted_decoder_hidden_state',
    )

    # [1, batch_size, encoder_output_dim]
    partial_sum = model.net.Add(
        [prev_context_proj, decoder_state_proj],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, partial_sum],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )

    logits_transposed = _calc_attention_logits_from_sum_match(
        model,
        decoder_hidden_encoder_outputs_sum,
        encoder_output_dim,
        scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [batch_size, encoder_output_dim, 1]
    weighted_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )

    extra_blobs = [decoder_hidden_encoder_outputs_sum]
    return weighted_context, attention_weights_3d, extra_blobs
|
||||
|
||||
|
||||
def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    """Build the regular (additive) attention sub-graph for one decoder step.

    The decoder hidden state is projected with an FC, broadcast-added to the
    weighted encoder outputs, and the sum is scored, normalised and used to
    weight the encoder outputs.

    Returns:
        A 3-tuple: (attention-weighted encoder context, 3-D attention
        weights, [decoder_hidden_encoder_outputs_sum]).
    """
    decoder_state_proj = _apply_fc_weight_for_sum_match(
        model,
        decoder_hidden_state_t,
        decoder_hidden_state_dim,
        encoder_output_dim,
        scope,
        'weighted_decoder_hidden_state',
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, decoder_state_proj],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )

    logits_transposed = _calc_attention_logits_from_sum_match(
        model,
        decoder_hidden_encoder_outputs_sum,
        encoder_output_dim,
        scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [batch_size, encoder_output_dim, 1]
    weighted_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )

    extra_blobs = [decoder_hidden_encoder_outputs_sum]
    return weighted_context, attention_weights_3d, extra_blobs
|
||||
|
||||
|
||||
def apply_dot_attention(
    model,
    encoder_output_dim,
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed,
    # [1, batch_size, decoder_state_dim]
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    """Build the dot-product attention sub-graph for one decoder step.

    When the decoder and encoder widths differ, the decoder state is first
    projected to ``encoder_output_dim`` with an FC; otherwise it is used as
    is. The logits come from a batched matmul of the encoder outputs
    (with ``trans_a=1``) against the decoder state.

    Returns:
        A 3-tuple: (attention-weighted encoder context, 3-D attention
        weights, []) -- the last element is always an empty list.
    """
    if decoder_hidden_state_dim == encoder_output_dim:
        # Widths already agree, no projection needed.
        decoder_state = decoder_hidden_state_t
    else:
        decoder_state = brew.fc(
            model,
            decoder_hidden_state_t,
            s(scope, 'weighted_decoder_hidden_state'),
            dim_in=decoder_hidden_state_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )

    # [batch_size, decoder_state_dim]
    decoder_state_2d = model.net.Squeeze(
        decoder_state,
        s(scope, 'squeezed_weighted_decoder_hidden_state'),
        dims=[0],
    )

    # ExpandDims runs in place on the squeezed blob.
    # [batch_size, decoder_state_dim, 1]
    decoder_state_col = model.net.ExpandDims(
        decoder_state_2d,
        decoder_state_2d,
        dims=[2],
    )

    # NOTE(review): the original comment gave this shape as
    # [batch_size, encoder_output_dim, 1]; given trans_a=1 and the shapes
    # annotated above it is presumably [batch_size, encoder_length, 1] --
    # confirm.
    logits_transposed = model.net.BatchMatMul(
        [encoder_outputs_transposed, decoder_state_col],
        s(scope, 'attention_logits'),
        trans_a=1,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [batch_size, encoder_output_dim, 1]
    weighted_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return weighted_context, attention_weights_3d, []
|
||||
|
||||
|
||||
def apply_soft_coverage_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths,
    coverage_t_prev,
    coverage_weights,
):
    """Build the soft-coverage attention sub-graph for one decoder step.

    On top of the additive attention sum, the previous coverage (squeezed and
    transposed) scales ``coverage_weights`` and is added into the sum before
    scoring. The new coverage is the previous coverage plus the 2-D attention
    weights of this step.

    Returns:
        A 4-tuple: (attention-weighted encoder context, 3-D attention
        weights, [decoder_hidden_encoder_outputs_sum], coverage_t).
    """
    decoder_state_proj = _apply_fc_weight_for_sum_match(
        model,
        decoder_hidden_state_t,
        decoder_hidden_state_dim,
        encoder_output_dim,
        scope,
        'weighted_decoder_hidden_state',
    )

    # [encoder_length, batch_size, encoder_output_dim]
    partial_sum = model.net.Add(
        [weighted_encoder_outputs, decoder_state_proj],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
        broadcast=1,
    )

    # [batch_size, encoder_length]
    coverage_prev_2d = model.net.Squeeze(
        coverage_t_prev,
        s(scope, 'coverage_t_prev_2d'),
        dims=[0],
    )

    # [encoder_length, batch_size]
    coverage_prev_transposed = brew.transpose(
        model,
        coverage_prev_2d,
        s(scope, 'coverage_t_prev_transposed'),
    )

    # [encoder_length, batch_size, encoder_output_dim]
    scaled_coverage = model.net.Mul(
        [coverage_weights, coverage_prev_transposed],
        s(scope, 'scaled_coverage_weights'),
        broadcast=1,
        axis=0,
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [partial_sum, scaled_coverage],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
    )

    # [batch_size, encoder_length, 1]
    logits_transposed = _calc_attention_logits_from_sum_match(
        model,
        decoder_hidden_encoder_outputs_sum,
        encoder_output_dim,
        scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [batch_size, encoder_output_dim, 1]
    weighted_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )

    # [batch_size, encoder_length]
    attention_weights_2d = model.net.Squeeze(
        attention_weights_3d,
        s(scope, 'attention_weights_2d'),
        dims=[2],
    )

    # Accumulate this step's attention into the running coverage.
    coverage_t = model.net.Add(
        [coverage_t_prev, attention_weights_2d],
        s(scope, 'coverage_t'),
        broadcast=1,
    )

    extra_blobs = [decoder_hidden_encoder_outputs_sum]
    return weighted_context, attention_weights_3d, extra_blobs, coverage_t
|
||||
137
caffe2/python/benchmark_generator.py
Normal file
137
caffe2/python/benchmark_generator.py
Normal file
@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import string
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
from caffe2.python.predictor import mobile_exporter
|
||||
from caffe2.python import core, workspace, brew, utils
|
||||
|
||||
|
||||
def parse_kwarg(kwarg_str):
    """Parse a ``key=value`` command-line token into a ``(key, value)`` pair.

    The string is split on the first ``=`` only, and both halves are stripped
    of surrounding whitespace. The value is converted to ``int`` if possible,
    otherwise to ``float``, otherwise it is left as a string.

    Args:
        kwarg_str: text of the form ``"key=value"``.

    Returns:
        Tuple ``(key, value)``.

    Raises:
        ValueError: if ``kwarg_str`` contains no ``=`` (unpacking fails).
    """
    # ``string.strip`` only existed in Python 2; use the str method instead.
    key, value = (part.strip() for part in kwarg_str.split("=", 1))
    # Try the narrowest numeric type first, falling back to the raw string.
    for convert in (int, float):
        try:
            return key, convert(value)
        except ValueError:
            continue
    return key, value
|
||||
|
||||
|
||||
def main(args):
    """Generate init/predict nets that run one brew operator ``iters`` times.

    Reads the parsed command-line namespace ``args`` (operator, blob, context,
    kwargs, init_net, predict_net, benchmark_name, input_name, output_name,
    iters, debug, chain), builds the model, exports it with
    ``mobile_exporter.Export`` and writes both serialized nets to disk.
    """
    # User defined keyword arguments
    kwargs = {"order": "NCHW"}
    kwargs.update(dict(args.kwargs))

    model = ModelHelper(name=args.benchmark_name)

    op_type = args.operator  # assumes a brew type op name
    input_name = args.input_name
    output_name = args.output_name

    iters = int(args.iters)
    for i in range(iters):
        input_blob_name = input_name + (str(i) if i > 0 and args.chain else '')
        output_blob_name = output_name + str(i + 1)
        add_op = getattr(brew, op_type)
        add_op(model, input_blob_name, output_blob_name, **kwargs)
        if args.chain:
            # Feed this op's output into the next op by swapping the prefixes.
            input_name, output_name = output_name, input_name

    workspace.RunNetOnce(model.param_init_net)
    extra_init_net_ops = []

    def make_blob_on_context(blob_name, blob_data, context):
        """Queue init-net ops that create ``blob_name`` filled with ``blob_data``."""
        if context.upper() != "CPU":
            blob_name_modified = "{}_CPU".format(blob_name)
        else:  # CPU case is simple
            blob_name_modified = blob_name

        fill_op = core.CreateOperator(
            "GivenTensorFill", [], [blob_name_modified],
            arg=[
                utils.MakeArgument("shape", blob_data.shape),
                utils.MakeArgument("values", blob_data)
            ]
        )
        extra_init_net_ops.append(fill_op)

        # We need to create CPU blobs and add some copy operations in
        # the init_net
        if context.upper() == "OPENGL":
            copy_op = core.CreateOperator("CopyToOpenGL", [blob_name_modified],
                                          [blob_name])
            extra_init_net_ops.append(copy_op)

    # ``--blob`` uses action='append', so it is None when never supplied;
    # treat that as "no extra blobs" instead of crashing on iteration.
    for unparsed_blob in args.blob or []:
        name, unparsed_dims = unparsed_blob.split('=')
        dims = [int(d) for d in unparsed_dims.split(',')]
        np_input = np.random.rand(*dims).astype(np.float32)
        make_blob_on_context(name, np_input, args.context)

    init_net, predict_net = mobile_exporter.Export(
        workspace, model.net, model.params
    )
    init_net.op.extend(extra_init_net_ops)

    # Handle manual rewrite
    if args.context.upper() == "OPENGL":
        old_ops = [op for op in predict_net.op]
        del predict_net.op[:]
        for op in old_ops:
            op.type = 'OpenGL{}'.format(op.type)
        predict_net.op.extend(old_ops)

    if args.debug:
        print("init_net:")
        for op in init_net.op:
            print(" ", op.type, op.input, "-->", op.output)
        print("predict_net:")
        for op in predict_net.op:
            print(" ", op.type, op.input, "-->", op.output)

    with open(args.predict_net, 'wb') as f:
        f.write(predict_net.SerializeToString())
    with open(args.init_net, 'wb') as f:
        f.write(init_net.SerializeToString())
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: declare the options, then delegate to main().
    cli = argparse.ArgumentParser(
        description="Utility to generate Caffe2 benchmark models.")
    cli.add_argument("operator", help="Caffe2 operator to benchmark.")
    cli.add_argument(
        "-b", "--blob",
        help="Instantiate a blob --blob name=dim1,dim2,dim3",
        action='append',
    )
    cli.add_argument("--context", help="Context to run on.", default="CPU")
    cli.add_argument(
        "--kwargs",
        help="kwargs to pass to operator.",
        nargs="*",
        type=parse_kwarg,
        default=[],
    )
    cli.add_argument(
        "--init_net",
        help="Output initialization net.",
        default="init_net.pb",
    )
    cli.add_argument(
        "--predict_net",
        help="Output prediction net.",
        default="predict_net.pb",
    )
    cli.add_argument(
        "--benchmark_name",
        help="Name of the benchmark network",
        default="benchmark",
    )
    cli.add_argument(
        "--input_name",
        help="Name of the input blob.",
        default="data",
    )
    cli.add_argument(
        "--output_name",
        help="Name of the output blob.",
        default="output",
    )
    # Kept as a string default; main() converts it with int().
    cli.add_argument(
        "--iters",
        help="Number of iterations to run the operator.",
        default="1",
    )
    cli.add_argument(
        "-d", "--debug",
        help="Print debug information.",
        action='store_true',
    )
    cli.add_argument(
        "-c", "--chain",
        help="Chain ops together (create data dependencies)",
        action='store_true',
    )
    main(cli.parse_args())
|
||||
31
caffe2/python/benchmarks/concat_benchmark.py
Normal file
31
caffe2/python/benchmarks/concat_benchmark.py
Normal file
@ -0,0 +1,31 @@
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from caffe2.python import core, workspace
|
||||
|
||||
|
||||
def benchmark_concat(num_inputs, input_dim, axis, add_axis, iterations):
    """Benchmark a single Concat op and print the measured throughput.

    Feeds ``num_inputs`` random float32 blobs of shape ``input_dim`` into the
    workspace, builds a one-op net concatenating them, runs
    ``workspace.BenchmarkNet`` and prints a "GB/s" figure derived from the
    second entry of the returned runtimes.
    """
    input_names = []
    for idx in range(num_inputs):
        blob_name = f"input{idx}"
        input_names.append(blob_name)
    for blob_name in input_names:
        random_values = np.random.randn(*input_dim).astype(np.float32)
        workspace.FeedBlob(blob_name, random_values)

    net = core.Net("benchmark_net")
    net.Concat(input_names, ["output", "split_info"], axis=axis, add_axis=add_axis)
    workspace.CreateNet(net)

    runtimes = workspace.BenchmarkNet(net.Name(), 1, iterations, True)
    # 4 bytes per float32 element across all inputs.
    throughput = num_inputs * np.prod(input_dim) * 4 / runtimes[1] / 1e6
    print(f"{throughput} GB/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse known options; anything unrecognised is forwarded to GlobalInit.
    cli = argparse.ArgumentParser(description="minimal benchmark for concat.")
    cli.add_argument("--num_inputs", type=int, default=2)
    cli.add_argument("--input_dim", nargs="+", type=int, required=True)
    cli.add_argument("--axis", type=int, default=-1)
    cli.add_argument("--add_axis", type=int, default=0)
    cli.add_argument("--iterations", type=int, default=64)
    opts, passthrough = cli.parse_known_args()

    core.GlobalInit(["python"] + passthrough)
    benchmark_concat(
        opts.num_inputs,
        opts.input_dim,
        opts.axis,
        opts.add_axis,
        opts.iterations,
    )
|
||||
@ -0,0 +1,50 @@
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from caffe2.python import core, workspace
|
||||
|
||||
|
||||
def main(bit_rate):
    """Benchmark FloatToFused<bit_rate>BitRowwiseQuantized twice.

    The first net runs the operator with ``engine="GREEDY"``, the second with
    no engine argument. Both read the same random float32 input blob and are
    timed with ``workspace.BenchmarkNet``.
    """
    # uncomment for debugging
    # np.random.seed(0)
    batchsize = 10 * 1000
    blocksize = 64
    print(batchsize, blocksize)

    random_input = np.random.rand(batchsize, blocksize).astype(np.float32)
    workspace.FeedBlob("input_data", random_input)

    op_type = "FloatToFused" + str(bit_rate) + "BitRowwiseQuantized"
    iterations = 10

    # First variant: GREEDY engine.
    greedy_net = core.Net("bench")
    greedy_op = core.CreateOperator(
        op_type, "input_data", "quantized_data", engine="GREEDY"
    )
    greedy_net.Proto().op.extend([greedy_op])
    workspace.GlobalInit(["caffe2", "--caffe2_log_level=0"])
    workspace.CreateNet(greedy_net)
    workspace.BenchmarkNet(greedy_net.Proto().name, 1, iterations, True)

    # Second variant: no engine argument.
    default_net = core.Net("bench2")
    default_op = core.CreateOperator(op_type, "input_data", "quantized_data")
    default_net.Proto().op.extend([default_op])

    workspace.CreateNet(default_net)
    workspace.BenchmarkNet(default_net.Proto().name, 1, iterations, True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Only one option: the quantization bit rate handed to main().
    cli = argparse.ArgumentParser(
        description="benchmark for row-wise 2/4-bit quantization."
    )
    cli.add_argument("--bit-rate", type=int, default=4)
    opts = cli.parse_args()
    main(opts.bit_rate)
|
||||
117
caffe2/python/benchmarks/sparse_lengths_sum_nbit_benchmark.py
Normal file
117
caffe2/python/benchmarks/sparse_lengths_sum_nbit_benchmark.py
Normal file
@ -0,0 +1,117 @@
|
||||
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
|
||||
import hypothesis.strategies as st
|
||||
import numpy as np
|
||||
from caffe2.python import core, workspace
|
||||
|
||||
|
||||
def benchmark_sparse_lengths_sum(
    categorical_limit,
    embedding_size,
    average_len,
    batch_size,
    iterations,
    flush_cache,
    bit_rate=4,
):
    """Benchmark SparseLengthsSumFused<bit_rate>BitRowwise and print throughput.

    A constant float32 table of shape [categorical_limit, embedding_size] is
    quantized with FloatToFused<bit_rate>BitRowwiseQuantized, random
    indices/lengths are generated by a Python operator in the init net, and
    the lookup net is timed with ``workspace.BenchmarkNet``.

    Args:
        categorical_limit: number of rows in the lookup table.
        embedding_size: number of columns in the lookup table.
        average_len: centre of the random per-example length range.
        batch_size: number of lengths generated.
        iterations: main-run iteration count passed to BenchmarkNet.
        flush_cache: when true, a Scale op over a large blob is put in front
            of the lookup op.
        bit_rate: spliced into the operator names. The previous default was
            a hypothesis strategy object (``st.sampled_from([2, 4])``), which
            stringified into an invalid operator name; it now defaults to 4,
            matching the command-line default.
    """
    print("Preparing lookup table. " + str(datetime.datetime.now()))

    # We will use a constant, but non-trivial value so we save initialization
    # time.
    data = np.ones([categorical_limit, embedding_size], dtype=np.float32)
    data *= 17.01

    init_net = core.Net("init_net")
    op = core.CreateOperator(
        "FloatToFused" + str(bit_rate) + "BitRowwiseQuantized", "X", "X_q"
    )
    init_net.Proto().op.extend([op])
    workspace.FeedBlob("X", data)

    print("Data has shape {} {}".format(data.shape, datetime.datetime.now()))

    # In order to produce truly random lengths and indices, we will embed a
    # Python operator in the net to generate them.
    def f(_, outputs):
        lengths = np.random.randint(
            int(average_len * 0.75), int(average_len * 1.25), batch_size
        ).astype(np.int32)
        indices = np.random.randint(0, categorical_limit, np.sum(lengths)).astype(
            np.int64
        )
        outputs[0].feed(indices)
        outputs[1].feed(lengths)

    init_net.Python(f)([], ["indices", "lengths"])
    workspace.RunNetOnce(init_net)

    net = core.Net("mynet")
    if flush_cache:
        l3_cache_size = 30 * 2 ** 20 // 4
        workspace.FeedBlob(
            "huge_blob", np.random.randn(l3_cache_size).astype(np.float32)
        )
        net.Scale("huge_blob", "huge_blob_2x", value=2.0)
    op = core.CreateOperator(
        "SparseLengthsSumFused" + str(bit_rate) + "BitRowwise",
        ["X_q", "indices", "lengths"],
        "Y",
    )
    net.Proto().op.extend([op])
    workspace.CreateNet(net)

    # Set random seed, so that repeated runs will keep the same sequence of
    # random indices.
    # NOTE(review): the init net that draws the indices has already run at
    # this point, so this seed presumably does not affect them -- confirm.
    np.random.seed(1701)

    print("Preparation finished. " + str(datetime.datetime.now()))

    runtimes = workspace.BenchmarkNet(net.Name(), 1, iterations, True)
    # With flush_cache the Scale op is first in the net, so the lookup op's
    # entry is one position later in the runtimes list.
    print(
        "{} billion sums per sec".format(
            embedding_size
            * workspace.FetchBlob("indices").size
            / runtimes[2 if flush_cache else 1]
            / 1e6
        )
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse known options; anything unrecognised is forwarded to GlobalInit.
    cli = argparse.ArgumentParser(
        description="minimal benchmark for sparse lengths sum."
    )
    cli.add_argument(
        "-e",
        "--embedding-size",
        type=int,
        default=6000000,
        help="Lookup table size.",
    )
    cli.add_argument(
        "--embedding-dim",
        type=int,
        default=128,
        help="Embedding dimension.",
    )
    cli.add_argument(
        "--average_len",
        type=int,
        default=27,
        help="Sparse feature average lengths, default is 27",
    )
    cli.add_argument("--batch_size", type=int, default=100, help="The batch size.")
    cli.add_argument(
        "-i",
        "--iteration",
        type=int,
        default=100000,
        help="The number of iterations.",
    )
    cli.add_argument(
        "--flush-cache",
        action="store_true",
        help="If true, flush cache",
    )
    cli.add_argument("--bit-rate", type=int, default=4)
    opts, passthrough = cli.parse_known_args()

    core.GlobalInit(["python"] + passthrough)
    # --embedding-size is the table's row count, --embedding-dim its width.
    benchmark_sparse_lengths_sum(
        opts.embedding_size,
        opts.embedding_dim,
        opts.average_len,
        opts.batch_size,
        opts.iteration,
        opts.flush_cache,
        opts.bit_rate,
    )
|
||||
121
caffe2/python/benchmarks/sparse_normalize_benchmark.py
Normal file
121
caffe2/python/benchmarks/sparse_normalize_benchmark.py
Normal file
@ -0,0 +1,121 @@
|
||||
import argparse
|
||||
import datetime
|
||||
|
||||
# import hypothesis.strategies as st
|
||||
import numpy as np
|
||||
from caffe2.python import core, workspace
|
||||
|
||||
|
||||
def benchmark_sparse_normalize(
    categorical_limit,
    embedding_size,
    average_len,
    batch_size,
    iterations,
    flush_cache,
    fp16,
):
    """Benchmark SparseNormalize (or Float16SparseNormalize) and print timings.

    A constant float32 table of shape [categorical_limit, embedding_size] is
    fed as "X" (and converted to "X_fp16" by the init net when ``fp16`` is
    set), random indices are produced by a Python operator, and the
    normalize net is timed with ``workspace.BenchmarkNet``.
    """
    print("Preparing lookup table. " + str(datetime.datetime.now()))

    # We will use a constant, but non-trivial value so we save initialization
    # time.
    table = np.ones([categorical_limit, embedding_size], dtype=np.float32)
    table *= 17.01

    init_net = core.Net("init_net")
    if fp16:
        to_half = core.CreateOperator("FloatToHalf", "X", "X_fp16")
        init_net.Proto().op.extend([to_half])
    l3_cache_size = 30 * 2 ** 20 // 4

    # In order to produce truly random lengths and indices, we will embed a
    # Python operator in the net to generate them.
    def f(_, outputs):
        # Lengths are only used to decide how many indices to draw.
        lengths = np.random.randint(
            int(average_len * 0.75), int(average_len * 1.25), batch_size
        ).astype(np.int32)
        total = np.sum(lengths)
        indices = np.random.randint(0, categorical_limit, total).astype(np.int64)
        outputs[0].feed(indices)

    workspace.FeedBlob("X", table)
    workspace.FeedBlob("huge_blob", np.random.randn(l3_cache_size).astype(np.float32))

    print("Data has shape {} {}".format(table.shape, datetime.datetime.now()))

    init_net.Python(f)([], ["indices"])
    workspace.RunNetOnce(init_net)

    # Pick operator and blob names according to the requested precision.
    if fp16:
        op_type = "Float16SparseNormalize"
        param_blob = "X_fp16"
    else:
        op_type = "SparseNormalize"
        param_blob = "X"

    net = core.Net("mynet")
    normalize_op = core.CreateOperator(op_type, [param_blob, "indices"], param_blob)
    for external in ("X", "X_fp16", "indices"):
        net.Proto().external_input.append(external)
    net.Proto().op.extend([normalize_op])
    if flush_cache:
        net.Scale("huge_blob", "huge_blob_2x", value=2.0)

    workspace.CreateNet(net)

    # Set random seed, so that repeated runs will keep the same sequence of
    # random indices.
    np.random.seed(1701)

    print("Preparation finished. " + str(datetime.datetime.now()))

    runtimes = workspace.BenchmarkNet(net.Name(), 1, iterations, True)

    # NOTE(review): here the Scale op is appended after the normalize op, so
    # with flush_cache index 2 presumably refers to Scale rather than the
    # normalize op -- confirm against BenchmarkNet's return layout.
    elapsed = runtimes[2 if flush_cache else 1]
    bytes_per_value = 2 if fp16 else 4

    print("{} ms".format(elapsed))
    print("indice_size: " + str(workspace.FetchBlob("indices").size))
    print(
        "{} GB/sec".format(
            bytes_per_value
            * embedding_size
            * workspace.FetchBlob("indices").size
            / elapsed
            / 1e6
        )
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse known options; anything unrecognised is forwarded to GlobalInit.
    cli = argparse.ArgumentParser(
        description="minimal benchmark for sparse lengths sum."
    )
    cli.add_argument(
        "-e",
        "--embedding-size",
        type=int,
        default=600000,
        help="Lookup table size.",
    )
    cli.add_argument(
        "--embedding-dim",
        type=int,
        default=128,
        help="Embedding dimension.",
    )
    cli.add_argument(
        "--average-len",
        type=int,
        default=27,
        help="Sparse feature average lengths, default is 27",
    )
    cli.add_argument("--batch_size", type=int, default=100, help="The batch size.")
    cli.add_argument(
        "-i",
        "--iteration",
        type=int,
        default=100,
        help="The number of iterations.",
    )
    cli.add_argument(
        "--flush-cache",
        action="store_true",
        help="If true, flush cache",
    )
    cli.add_argument("--fp16", action="store_true", help="If true, use fp16")
    opts, passthrough = cli.parse_known_args()
    core.GlobalInit(["python"] + passthrough)

    # --embedding-size is the table's row count, --embedding-dim its width.
    benchmark_sparse_normalize(
        opts.embedding_size,
        opts.embedding_dim,
        opts.average_len,
        opts.batch_size,
        opts.iteration,
        opts.flush_cache,
        opts.fp16,
    )
|
||||
164
caffe2/python/binarysize.py
Normal file
164
caffe2/python/binarysize.py
Normal file
@ -0,0 +1,164 @@
|
||||
"""A tool to inspect the binary size of a built binary file.
|
||||
|
||||
This script prints out a tree of symbols and their corresponding sizes, using
|
||||
Linux's nm functionality.
|
||||
|
||||
Usage:
|
||||
|
||||
python binarysize.py -- \
|
||||
--target=/path/to/your/target/binary \
|
||||
[--nm_command=/path/to/your/custom/nm] \
|
||||
[--max_depth=10] [--min_size=1024] \
|
||||
[--color] \
|
||||
|
||||
To assist visualization, pass in '--color' to make the symbols color coded to
|
||||
green, assuming that you have a xterm connection that supports color.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
class Trie:
    """Node of a character trie that accumulates symbol sizes."""

    def __init__(self, name):
        """Create a node labelled ``name`` with zero size and no children."""
        # Children keyed by the next character of the symbol name.
        self.dictionary = {}
        # Accumulated size of every symbol passing through this node.
        self.size = 0
        # Prefix that this node represents.
        self.name = name
|
||||
|
||||
|
||||
def GetSymbolTrie(target, nm_command, max_depth):
    """Gets a symbol trie with the passed in target.

    Runs ``nm`` on the target, pipes its output through ``c++filt`` and
    inserts every (size, name) pair into a character trie, adding the size to
    each node along the name's path until the prefix exceeds ``max_depth``.
    If either command exits non-zero, its output is printed and the process
    exits with status 1.

    Args:
        target: the target binary to inspect.
        nm_command: the command to run nm.
        max_depth: the maximum depth to create the trie.

    Returns:
        The root Trie; its size is the sum of its direct children's sizes.
    """
    # Run nm to get a dump on the strings.
    # universal_newlines=True makes the pipes text-mode, so the output is str
    # (not bytes) and can be split on '\n' below under Python 3.
    proc = subprocess.Popen(
        [nm_command, '--radix=d', '--size-sort', '--print-size', target],
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        universal_newlines=True)
    nm_out, _ = proc.communicate()
    if proc.returncode != 0:
        print('NM command failed. Output is as follows:')
        print(nm_out)
        sys.exit(1)
    # Run c++filt to get proper symbols.
    proc = subprocess.Popen(['c++filt'],
                            stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            universal_newlines=True)
    out, _ = proc.communicate(input=nm_out)
    if proc.returncode != 0:
        print('c++filt failed. Output is as follows:')
        print(out)
        sys.exit(1)
    # Splits the output to size and function name.
    data = []
    for line in out.split('\n'):
        if line:
            content = line.split(' ')
            if len(content) < 4:
                # This is a line not representing symbol sizes. skip.
                continue
            # Field 1 is the size; the name may itself contain spaces.
            data.append([int(content[1]), ' '.join(content[3:])])
    symbol_trie = Trie('')
    for size, name in data:
        curr = symbol_trie
        for c in name:
            if c not in curr.dictionary:
                curr.dictionary[c] = Trie(curr.name + c)
            curr = curr.dictionary[c]
            curr.size += size
            if len(curr.name) > max_depth:
                break
    symbol_trie.size = sum(t.size for t in symbol_trie.dictionary.values())
    return symbol_trie
|
||||
|
||||
|
||||
def MaybeAddColor(s, color):
    """Return ``s`` wrapped in the xterm green escape codes when ``color`` is set.

    When ``color`` is falsy the input is returned unchanged.
    """
    if not color:
        return s
    return '\033[92m{0}\033[0m'.format(s)
|
||||
|
||||
|
||||
def ReadableSize(num):
    """Format ``num`` bytes as a human-readable string.

    The value is divided by 1024 until its magnitude is at most 1024, walking
    through B, KB, MB and GB; anything still larger is reported in TB.
    """
    remaining = num
    units = ('B', 'KB', 'MB', 'GB')
    idx = 0
    while idx < len(units):
        if abs(remaining) <= 1024.0:
            return '%3.2f%s' % (remaining, units[idx])
        remaining = remaining / 1024.0
        idx += 1
    return '%.1f TB' % (remaining,)
|
||||
|
||||
|
||||
# Note(jiayq): I know, I know, this is a recursive function, but it is
|
||||
# convenient to write.
|
||||
def PrintTrie(trie, prefix, max_depth, min_size, color):
    """Prints the symbol trie in a readable manner.

    Args:
        trie: node to print (has ``name``, ``size`` and ``dictionary``).
        prefix: text printed in front of each line for this level.
        max_depth: name length at which a node is printed without descending.
        min_size: nodes whose size is not above this are not printed.
        color: passed to MaybeAddColor for the node name.
    """
    if len(trie.name) == max_depth or not trie.dictionary:
        # If we are reaching a leaf node or the maximum depth, we will print the
        # result.
        if trie.size > min_size:
            print('{0}{1} {2}'.format(
                prefix,
                MaybeAddColor(trie.name, color),
                ReadableSize(trie.size)))
    elif len(trie.dictionary) == 1:
        # There is only one child in this dictionary, so we will just delegate
        # to the downstream trie to print stuff.
        # dict views cannot be indexed in Python 3, so take the sole child
        # through an iterator instead of ``values()[0]``.
        only_child = next(iter(trie.dictionary.values()))
        PrintTrie(only_child, prefix, max_depth, min_size, color)
    elif trie.size > min_size:
        print('{0}{1} {2}'.format(
            prefix,
            MaybeAddColor(trie.name, color),
            ReadableSize(trie.size)))
        keys_with_sizes = [
            (k, child.size) for k, child in trie.dictionary.items()]
        # Ascending stable sort walked backwards keeps the original ordering
        # of equally sized children.
        keys_with_sizes.sort(key=lambda x: x[1])
        for k, _ in reversed(keys_with_sizes):
            PrintTrie(
                trie.dictionary[k], prefix + ' |', max_depth, min_size, color)
|
||||
|
||||
|
||||
def main(argv):
    """Parse ``argv``, build the symbol trie for the target and print it.

    Raises:
        RuntimeError: when not running on Linux, or when no --target is given.
    """
    if not sys.platform.startswith('linux'):
        raise RuntimeError('Currently this tool only supports Linux.')

    cli = argparse.ArgumentParser(description="Tool to inspect binary size.")
    cli.add_argument(
        '--max_depth',
        type=int,
        default=10,
        help='The maximum depth to print the symbol tree.',
    )
    cli.add_argument(
        '--min_size',
        type=int,
        default=1024,
        help='The mininum symbol size to print.',
    )
    cli.add_argument(
        '--nm_command',
        type=str,
        default='nm',
        help='The path to the nm command that the tool needs.',
    )
    cli.add_argument(
        '--color',
        action='store_true',
        help='If set, use ascii color for output.',
    )
    cli.add_argument(
        '--target',
        type=str,
        help='The binary target to inspect.',
    )
    options = cli.parse_args(argv)

    # --target has no default, so it must be checked by hand.
    if not options.target:
        raise RuntimeError('You must specify a target to inspect.')

    root = GetSymbolTrie(options.target, options.nm_command, options.max_depth)
    PrintTrie(root, '', options.max_depth, options.min_size, options.color)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Drop the script name and hand the remaining arguments to main().
    cli_args = sys.argv[1:]
    main(cli_args)
|
||||
139
caffe2/python/brew.py
Normal file
139
caffe2/python/brew.py
Normal file
@ -0,0 +1,139 @@
|
||||
## @package model_helper_api
|
||||
# Module caffe2.python.model_helper_api
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import sys
|
||||
import copy
|
||||
import inspect
|
||||
from past.builtins import basestring
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
|
||||
# flake8: noqa
|
||||
from caffe2.python.helpers.algebra import *
|
||||
from caffe2.python.helpers.arg_scope import *
|
||||
from caffe2.python.helpers.array_helpers import *
|
||||
from caffe2.python.helpers.control_ops import *
|
||||
from caffe2.python.helpers.conv import *
|
||||
from caffe2.python.helpers.db_input import *
|
||||
from caffe2.python.helpers.dropout import *
|
||||
from caffe2.python.helpers.elementwise_linear import *
|
||||
from caffe2.python.helpers.fc import *
|
||||
from caffe2.python.helpers.nonlinearity import *
|
||||
from caffe2.python.helpers.normalization import *
|
||||
from caffe2.python.helpers.pooling import *
|
||||
from caffe2.python.helpers.quantization import *
|
||||
from caffe2.python.helpers.tools import *
|
||||
from caffe2.python.helpers.train import *
|
||||
|
||||
|
||||
class HelperWrapper(object):
    """Attribute-based dispatcher for registered helper functions.

    Looking up ``wrapper.<helper_name>`` returns a callable that merges
    default keyword arguments from the model's ``arg_scope`` and from the
    currently active ``arg_scope`` context before calling the registered
    helper. Explicit keyword arguments always win.
    """

    # Maps helper name -> helper function. Shared at class level.
    _registry = {
        'arg_scope': arg_scope,
        'fc': fc,
        'packed_fc': packed_fc,
        'fc_decomp': fc_decomp,
        'fc_sparse': fc_sparse,
        'fc_prune': fc_prune,
        'dropout': dropout,
        'max_pool': max_pool,
        'average_pool': average_pool,
        'max_pool_with_index': max_pool_with_index,
        'lrn': lrn,
        'softmax': softmax,
        'instance_norm': instance_norm,
        'spatial_bn': spatial_bn,
        'spatial_gn': spatial_gn,
        'moments_with_running_stats': moments_with_running_stats,
        'relu': relu,
        'prelu': prelu,
        'tanh': tanh,
        'concat': concat,
        'depth_concat': depth_concat,
        'sum': sum,
        'reduce_sum': reduce_sum,
        'sub': sub,
        'arg_min': arg_min,
        'transpose': transpose,
        'iter': iter,
        'accuracy': accuracy,
        'conv': conv,
        'conv_nd': conv_nd,
        'conv_transpose': conv_transpose,
        'group_conv': group_conv,
        'group_conv_deprecated': group_conv_deprecated,
        'image_input': image_input,
        'video_input': video_input,
        'add_weight_decay': add_weight_decay,
        'elementwise_linear': elementwise_linear,
        'layer_norm': layer_norm,
        'mat_mul': mat_mul,
        'batch_mat_mul': batch_mat_mul,
        'cond': cond,
        'loop': loop,
        'db_input': db_input,
        'fused_8bit_rowwise_quantized_to_float': fused_8bit_rowwise_quantized_to_float,
        'sparse_lengths_sum_4bit_rowwise_sparse': sparse_lengths_sum_4bit_rowwise_sparse,
    }

    def __init__(self, wrapped):
        """Keep a reference to the object being wrapped."""
        self.wrapped = wrapped

    def __getattr__(self, helper_name):
        """Return a scope-aware wrapper around the helper `helper_name`.

        Raises:
            AttributeError: if no helper with that name is registered.
        """
        if helper_name not in self._registry:
            raise AttributeError(
                "Helper function {} not "
                "registered.".format(helper_name)
            )

        def scope_wrapper(*args, **kwargs):
            new_kwargs = {}
            if helper_name != 'arg_scope':
                # Every helper except arg_scope itself needs a model, either
                # as the first positional argument or as `model=...`.
                if len(args) > 0 and isinstance(args[0], ModelHelper):
                    model = args[0]
                elif 'model' in kwargs:
                    model = kwargs['model']
                else:
                    raise RuntimeError(
                        "The first input of helper function should be model. "
                        "Or you can provide it in kwargs as model=<your_model>.")
                # Deep copy so the helper cannot mutate the model's defaults.
                new_kwargs = copy.deepcopy(model.arg_scope)
            func = self._registry[helper_name]
            # getfullargspec replaces inspect.getargspec, which was removed
            # in Python 3.11 and rejected keyword-only/annotated signatures.
            spec = inspect.getfullargspec(func)
            if spec.varkw is None:
                # This helper does not take arbitrary **kwargs, so only pass
                # the model-level defaults it explicitly declares.
                accepted = list(spec.args) + list(spec.kwonlyargs)
                new_kwargs = {
                    var_name: new_kwargs[var_name]
                    for var_name in accepted if var_name in new_kwargs
                }

            # Precedence, lowest to highest: model arg_scope, active
            # arg_scope context for this helper, explicit call kwargs.
            cur_scope = get_current_scope()
            new_kwargs.update(cur_scope.get(helper_name, {}))
            new_kwargs.update(kwargs)
            return func(*args, **new_kwargs)

        scope_wrapper.__name__ = helper_name
        return scope_wrapper

    def Register(self, helper):
        """Register `helper` under its ``__name__``.

        Raises:
            AttributeError: if a helper with the same name already exists.
        """
        name = helper.__name__
        if name in self._registry:
            raise AttributeError(
                "Helper {} already exists. Please change your "
                "helper name.".format(name)
            )
        self._registry[name] = helper

    def has_helper(self, helper_or_helper_name):
        """Return True if the given helper (or helper name) is registered."""
        helper_name = (
            helper_or_helper_name
            if isinstance(helper_or_helper_name, basestring) else
            helper_or_helper_name.__name__
        )
        return helper_name in self._registry
|
||||
|
||||
|
||||
# Replace this module's entry in sys.modules with a HelperWrapper instance, so
# attribute access on the imported module is served by
# HelperWrapper.__getattr__. The original module object is passed in as the
# wrapped object.
# pyre-fixme[6]: incompatible parameter type: expected ModuleType, got HelperWrapper
sys.modules[__name__] = HelperWrapper(sys.modules[__name__])
|
||||
328
caffe2/python/brew_test.py
Normal file
328
caffe2/python/brew_test.py
Normal file
@ -0,0 +1,328 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import brew, core, scope, workspace
|
||||
from caffe2.python.modeling.parameter_info import ParameterTags
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
from caffe2.python.cnn import CNNModelHelper
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BrewTest(unittest.TestCase):
    """Tests for the brew helper API and related ModelHelper behaviour."""

    def setUp(self):
        """Register two trivial helpers (once) and create a fresh model."""

        def myhelper(model, val=-1):
            return val

        # Registering an existing name raises, so only register when missing.
        if not brew.has_helper(myhelper):
            brew.Register(myhelper)
        self.myhelper = myhelper

        def myhelper2(model, val=-1):
            return val

        if not brew.has_helper(myhelper2):
            brew.Register(myhelper2)
        self.myhelper2 = myhelper2
        self.model = ModelHelper(name="test_model")

    def test_dropout(self):
        """Dropout with is_test=False keeps the output mean within 0.05 of 1 - p."""
        p = 0.2
        X = np.ones((100, 100)).astype(np.float32) - p
        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        brew.dropout(model, "x", "out", is_test=False)
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        out = workspace.FetchBlob("out")
        self.assertLess(abs(out.mean() - (1 - p)), 0.05)

    def test_fc(self):
        """A model with a single brew.fc validates and runs."""
        m, n, k = (15, 15, 15)
        X = np.random.rand(m, k).astype(np.float32) - 0.5

        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        brew.fc(model, "x", "out_1", k, n)
        model.Validate()
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

    def test_relu(self):
        """ReLU keeps the positive input (mean 0.5) and zeroes the negative one."""
        Xpos = np.ones((5, 5)).astype(np.float32) - 0.5
        Xneg = np.ones((5, 5)).astype(np.float32) - 1.5

        workspace.FeedBlob("xpos", Xpos)
        workspace.FeedBlob("xneg", Xneg)
        model = ModelHelper(name="test_model")
        brew.relu(model, "xpos", "out_xpos")
        brew.relu(model, "xneg", "out_xneg")
        model.Validate()
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        pos = workspace.FetchBlob("out_xpos")
        self.assertAlmostEqual(pos.mean(), 0.5)
        neg = workspace.FetchBlob("out_xneg")
        self.assertAlmostEqual(neg.mean(), 0)

    def test_tanh(self):
        """Tanh of a constant 0.5 input has mean np.tanh(0.5) to 5 places."""
        X = np.ones((5, 5)).astype(np.float32) - 0.5

        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        brew.tanh(model, "x", "out_tanh")
        model.Validate()
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        out = workspace.FetchBlob("out_tanh")
        self.assertAlmostEqual(out.mean(), np.tanh(0.5), places=5)

    def test_validate(self):
        """_Validate() returns [] for unique params and the duplicated name otherwise."""
        model = ModelHelper(name="test_model")
        model.params.append("aaa")
        model.params.append("bbb")
        self.assertEqual(model._Validate(), [])

        model.params.append("xxx")
        model.params.append("bbb")
        self.assertEqual(model._Validate(), ["bbb"])

    def test_arg_scope(self):
        """arg_scope supplies `val` to one helper or to several helpers at once."""
        myhelper = self.myhelper
        myhelper2 = self.myhelper2
        n = 15
        with brew.arg_scope([myhelper], val=n):
            res = brew.myhelper(self.model)
        self.assertEqual(n, res)

        with brew.arg_scope([myhelper, myhelper2], val=n):
            res1 = brew.myhelper(self.model)
            res2 = brew.myhelper2(self.model)
        self.assertEqual([n, n], [res1, res2])

    def test_arg_scope_single(self):
        """arg_scope accepts a single helper; the conv output is (64, 64, 17, 17)."""
        X = np.random.rand(64, 3, 32, 32).astype(np.float32) - 0.5

        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        with brew.arg_scope(
            brew.conv,
            stride=2,
            pad=2,
            weight_init=('XavierFill', {}),
            bias_init=('ConstantFill', {})
        ):
            brew.conv(
                model=model,
                blob_in="x",
                blob_out="out",
                dim_in=3,
                dim_out=64,
                kernel=3,
            )
        model.Validate()
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        out = workspace.FetchBlob("out")
        self.assertEqual(out.shape, (64, 64, 17, 17))

    def test_arg_scope_nested(self):
        """Inner scopes override outer ones; explicit kwargs override scopes."""
        myhelper = self.myhelper
        n = 16
        with brew.arg_scope([myhelper], val=-3), \
                brew.arg_scope([myhelper], val=-2):
            with brew.arg_scope([myhelper], val=n):
                res = brew.myhelper(self.model)
                self.assertEqual(n, res)
            # Back in the outer pair of scopes: the later one (val=-2) applies.
            res = brew.myhelper(self.model)
            self.assertEqual(res, -2)

        res = brew.myhelper(self.model, val=15)
        self.model.Validate()
        self.assertEqual(res, 15)

    def test_double_register(self):
        """Registering an already registered helper raises AttributeError."""
        myhelper = self.myhelper
        with self.assertRaises(AttributeError):
            brew.Register(myhelper)

    def test_has_helper(self):
        """has_helper accepts a helper object or a name; unknown helpers are False."""
        self.assertTrue(brew.has_helper(brew.conv))
        self.assertTrue(brew.has_helper("conv"))

        def myhelper3():
            pass

        self.assertFalse(brew.has_helper(myhelper3))

    def test_model_helper(self):
        """A model-level arg_scope (order='NHWC') reaches brew.conv."""
        X = np.random.rand(64, 32, 32, 3).astype(np.float32) - 0.5

        workspace.FeedBlob("x", X)
        my_arg_scope = {'order': 'NHWC'}
        model = ModelHelper(name="test_model", arg_scope=my_arg_scope)
        with brew.arg_scope(
            brew.conv,
            stride=2,
            pad=2,
            weight_init=('XavierFill', {}),
            bias_init=('ConstantFill', {})
        ):
            brew.conv(
                model=model,
                blob_in="x",
                blob_out="out",
                dim_in=3,
                dim_out=64,
                kernel=[8, 3]
            )
        model.Validate()
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        out = workspace.FetchBlob("out")
        self.assertEqual(out.shape, (64, 15, 17, 64))

    def test_cnn_model_helper_deprecated(self):
        """CNNModelHelper stores its `order` argument in arg_scope."""
        X = np.random.rand(64, 32, 32, 3).astype(np.float32) - 0.5

        workspace.FeedBlob("x", X)
        # CNNModelHelper is going to be deprecated soon. This test is only
        # covering some CNNModelHelper logic
        model = CNNModelHelper(name="test_model", order='NHWC')
        self.assertEqual(model.arg_scope['order'], 'NHWC')

    def test_get_params(self):
        """GetParams/GetComputedParams/GetAllParams respect the active name scope."""
        def param(x):
            return core.ScopedBlobReference(x)

        def to_str_list(x):
            return sorted([str(p) for p in x])

        model = ModelHelper(name="test_model")
        model.AddParameter(param("a"))
        model.AddParameter(param("b"), tags=ParameterTags.COMPUTED_PARAM)
        with scope.NameScope("c"):
            model.AddParameter(param("a"))
            model.AddParameter(param("d"), tags=ParameterTags.COMPUTED_PARAM)
            self.assertEqual(to_str_list(model.GetParams()), ['c/a'])
            self.assertEqual(to_str_list(model.GetComputedParams()), ['c/d'])
            self.assertEqual(to_str_list(model.GetAllParams()), ['c/a', 'c/d'])
            # Get AllParams from the global Scope
            self.assertEqual(to_str_list(model.GetAllParams('')), [
                'a', 'b', 'c/a', 'c/d'])
        self.assertEqual(to_str_list(model.GetParams()), ['a', 'c/a'])
        self.assertEqual(to_str_list(model.GetComputedParams()), ['b', 'c/d'])
        self.assertEqual(to_str_list(model.GetAllParams()),
                         ['a', 'b', 'c/a', 'c/d'])
        self.assertEqual(to_str_list(model.GetAllParams('')),
                         ['a', 'b', 'c/a', 'c/d'])
        # Get AllParams from the scope 'c'
        self.assertEqual(to_str_list(model.GetAllParams('c')), ['c/a', 'c/d'])
        self.assertEqual(to_str_list(model.GetAllParams('c/')), ['c/a', 'c/d'])

    def test_param_consistence(self):
        """A model created with param_model shares _parameters_info with it."""
        model = ModelHelper(name='test_mode')
        cnv = brew.conv(model, 'data', 'cnv', 32, 32, 4)
        step_model = ModelHelper(name='step_model', param_model=model)
        a = brew.fc(step_model, cnv, 'a', 100, 200)
        brew.fc(model, a, 'b', 200, 5)
        # test the _parameters_info is shared between model and step_model
        self.assertEqual(model._parameters_info, step_model._parameters_info)

    def test_cond(self):
        """brew.cond runs then_model when "cond" is True and else_model otherwise."""
        workspace.FeedBlob("cond", np.array(True))
        workspace.FeedBlob("then_value", np.array(1))
        workspace.FeedBlob("else_value", np.array(2))

        then_model = ModelHelper(name="then_test_model")
        then_model.net.Copy("then_value", "output_blob")

        else_model = ModelHelper(name="else_test_model")
        else_model.net.Copy("else_value", "output_blob")

        model = ModelHelper(name="test_model")
        brew.cond(
            model=model,
            cond_blob="cond",
            external_blobs=["then_value", "else_value", "output_blob"],
            then_model=then_model,
            else_model=else_model)

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        output_value = workspace.FetchBlob("output_blob")
        self.assertEqual(output_value, 1)
        # Flip the condition and re-run the same net.
        workspace.FeedBlob("cond", np.array(False))
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        output_value = workspace.FetchBlob("output_blob")
        self.assertEqual(output_value, 2)

    def test_loop(self):
        """brew.loop with a counter-based condition leaves output_blob at 18."""
        workspace.FeedBlob("cond", np.array(True))
        workspace.FeedBlob("ONE", np.array(1))
        workspace.FeedBlob("TWO", np.array(2))
        workspace.FeedBlob("TEN", np.array(10))
        workspace.FeedBlob("counter", np.array(0))
        workspace.FeedBlob("output_blob", np.array(0))

        # Loop body: output_blob += 2.
        loop_model = ModelHelper(name="loop_test_model")
        loop_model.net.Add(["output_blob", "TWO"], "output_blob")

        # Condition net: counter += 1; cond = counter < 10.
        cond_model = ModelHelper(name="cond_test_model")
        cond_model.net.Add(["counter", "ONE"], "counter")
        comp_res = cond_model.net.LT(["counter", "TEN"])
        cond_model.net.Copy(comp_res, "cond")

        model = ModelHelper(name="test_model")
        brew.loop(
            model=model,
            cond_blob="cond",
            external_blobs=["cond", "ONE", "TWO", "TEN", "counter", "output_blob"],
            loop_model=loop_model,
            cond_model=cond_model)

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        output_value = workspace.FetchBlob("output_blob")
        self.assertEqual(output_value, 18)
|
||||
|
||||
|
||||
@unittest.skipIf(not workspace.has_gpu_support, "No gpu support.")
class BrewGPUTest(unittest.TestCase):
    """relu/tanh tests with use_cudnn=True; skipped without GPU support."""

    def test_relu(self):
        """ReLU with use_cudnn=True: positive input keeps mean 0.5, negative gives 0."""
        Xpos = np.ones((5, 5)).astype(np.float32) - 0.5
        Xneg = np.ones((5, 5)).astype(np.float32) - 1.5

        workspace.FeedBlob("xpos", Xpos)
        workspace.FeedBlob("xneg", Xneg)
        model = ModelHelper(name="test_model")
        brew.relu(model, "xpos", "out_xpos", use_cudnn=True)
        brew.relu(model, "xneg", "out_xneg", use_cudnn=True)
        model.Validate()
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        pos = workspace.FetchBlob("out_xpos")
        self.assertAlmostEqual(pos.mean(), 0.5)
        neg = workspace.FetchBlob("out_xneg")
        self.assertAlmostEqual(neg.mean(), 0)

    def test_tanh(self):
        """Tanh with use_cudnn=True: mean equals np.tanh(0.5) to 5 places."""
        X = np.ones((5, 5)).astype(np.float32) - 0.5

        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        brew.tanh(model, "x", "out_tanh", use_cudnn=True)
        model.Validate()
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        out = workspace.FetchBlob("out_tanh")
        self.assertAlmostEqual(out.mean(), np.tanh(0.5), places=5)
|
||||
9
caffe2/python/build.py
Normal file
9
caffe2/python/build.py
Normal file
@ -0,0 +1,9 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import caffe2.python._import_c_extension as C
|
||||
|
||||
# Re-export the C extension's `define_caffe2_no_operator_schema` attribute.
CAFFE2_NO_OPERATOR_SCHEMA = C.define_caffe2_no_operator_schema
# Result of the C extension's get_build_options(), evaluated once at import.
build_options = C.get_build_options()
|
||||
133
caffe2/python/cached_reader.py
Normal file
133
caffe2/python/cached_reader.py
Normal file
@ -0,0 +1,133 @@
|
||||
## @package cached_reader
|
||||
# Module caffe2.python.cached_reader
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import os
|
||||
|
||||
from caffe2.python import core
|
||||
from caffe2.python.db_file_reader import DBFileReader
|
||||
from caffe2.python.pipeline import pipe
|
||||
from caffe2.python.task import Cluster, TaskGroup
|
||||
|
||||
|
||||
class CachedReader(DBFileReader):
    """Reader with persistent in-file cache.

    Example usage:
    cached_reader = CachedReader(
        reader,
        db_path='/tmp/cache.db',
        db_type='LevelDB',
    )
    build_cache_step = cached_reader.build_cache_step()
    with LocalSession() as session:
        session.run(build_cache_step)

    Every time a new CachedReader is created, it's expected that
    db_path exists before calling .setup_ex(...) and .read(...).

    If db_path doesn't exist, build_cache_step is expected to be called
    first to build a cache at db_path.

    build_cache_step will check the existence of the provided db_path and,
    if it is missing, will initialize it by reading data from the original
    reader. All subsequent attempts to read will ignore the original reader
    (i.e. no additional data will be read from it).

    Args:
        original_reader: Reader.
            The original reader used to build the cache file. Must not be
            None.
        db_path: str.

    Optional Args:
        db_type: str. DB type of file. A db_type is registered by
            `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
            Default to 'LevelDB'.
        name: str or None. Name of CachedReader.
            Optional name to prepend to blobs that will store the data.
            Default to '<db_name>_<default_name_suffix>'.
        batch_size: int.
            How many examples are read for each time the read_net is run.
            Defaults to 100.
        loop_over: bool.
            If True given, will go through examples in random order endlessly.
            Defaults to False.
    """

    # NOTE: the docstring above must stay the first statement of the class
    # body; placed after this assignment it is not picked up as __doc__.
    default_name_suffix = 'cached_reader'

    def __init__(
        self,
        original_reader,
        db_path,
        db_type='LevelDB',
        name=None,
        batch_size=100,
        loop_over=False,
    ):
        assert original_reader is not None, "original_reader can't be None"
        # Set before super().__init__ so the reader is already available
        # while the base class initializes (see _init_reader_schema).
        self.original_reader = original_reader

        super().__init__(
            db_path,
            db_type,
            name,
            batch_size,
            loop_over,
        )

    def _init_reader_schema(self, *args, **kwargs):
        """Prepare the reader schema.

        Since an original reader is given,
        use its schema as ground truth.

        Returns:
            schema: schema.Struct. Used in Reader.__init__(...).
        """
        return self.original_reader._schema

    def build_cache_step(self, overwrite=False):
        """Build a step for generating the cache DB file.

        If self.db_path exists and we are not overwriting, build an empty
        step. Otherwise, build a step as follows:
        pipe the original reader to the _DatasetWriter, so that dataset
        field blobs are populated, then save these blobs into a file.

        Args:
            overwrite: bool. If true, ignore the existing file
                and build a new one, overwriting the existing one anyway.

        Returns:
            build_cache_step: ExecutionStep.
                The step to be run for building a cache DB file.
        """
        if os.path.exists(self.db_path) and not overwrite:
            # cache already exists, no need to rebuild it
            return core.execution_step('build_step', [])

        init_net = core.Net('init')
        self._init_field_blobs_as_empty(init_net)
        with Cluster(), core.NameScope(self.name), TaskGroup() as copy_tg:
            pipe(self.original_reader, self.ds.writer(), num_threads=16)
            copy_step = copy_tg.to_task().get_step()
        save_net = core.Net('save')
        self._save_field_blobs_to_db_file(save_net)

        return core.execution_step('build_cache', [init_net, copy_step, save_net])

    def _save_field_blobs_to_db_file(self, net):
        """Save dataset field blobs to a DB file at db_path"""
        net.Save(
            self.ds.get_blobs(),
            [],
            db=self.db_path,
            db_type=self.db_type,
            blob_name_overrides=self.ds.field_names(),
            absolute_path=True,
        )
|
||||
937
caffe2/python/caffe_translator.py
Normal file
937
caffe2/python/caffe_translator.py
Normal file
@ -0,0 +1,937 @@
|
||||
## @package caffe_translator
|
||||
# Module caffe2.python.caffe_translator
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
import numpy as np # noqa
|
||||
|
||||
from caffe2.proto import caffe2_pb2, caffe2_legacy_pb2
|
||||
from caffe.proto import caffe_pb2
|
||||
from caffe2.python import core, utils, workspace
|
||||
from google.protobuf import text_format
|
||||
|
||||
# Module-level logging setup: configure the root logger with defaults and
# create the translator's logger at INFO level.
logging.basicConfig()
log = logging.getLogger("caffe_translator")
log.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def _StateMeetsRule(state, rule):
|
||||
"""A function that reproduces Caffe's StateMeetsRule functionality."""
|
||||
if rule.HasField('phase') and rule.phase != state.phase:
|
||||
return False
|
||||
if rule.HasField('min_level') and state.level < rule.min_level:
|
||||
return False
|
||||
if rule.HasField('max_level') and state.level > rule.max_level:
|
||||
return False
|
||||
curr_stages = set(list(state.stage))
|
||||
# all stages in rule.stages should be in, otherwise it's not a match.
|
||||
if len(rule.stage) and any([s not in curr_stages for s in rule.stage]):
|
||||
return False
|
||||
# none of the stage in rule.stages should be in, otherwise it's not a match.
|
||||
if len(rule.not_stage) and any([s in curr_stages for s in rule.not_stage]):
|
||||
return False
|
||||
# If none of the nonmatch happens, return True.
|
||||
return True
|
||||
|
||||
|
||||
def _ShouldInclude(net_state, layer):
    """Decide whether `layer` is part of the net for `net_state`.

    Reproduces Caffe's include/exclude rule: a layer with no include rules is
    kept unless some exclude rule matches; a layer with include rules is kept
    exactly when at least one of them matches.
    """
    if len(layer.include) == 0:
        # No include rules: included by default, unless an exclusion is met.
        return not any(
            _StateMeetsRule(net_state, rule) for rule in layer.exclude)
    # Include rules present: any matching inclusion decides.
    return any(_StateMeetsRule(net_state, rule) for rule in layer.include)
|
||||
|
||||
|
||||
def _GetLegacyDims(net, net_params, dummy_input, legacy_pad_ops):
    """Run `net` on `dummy_input` and record output shapes of legacy-pad ops.

    Args:
        net: net definition whose operators are run one by one.
        net_params: container whose `protos` are fed as parameter blobs.
        dummy_input: array fed to the first input of the first operator.
        legacy_pad_ops: indices (into net.op) of the operators to record.

    Returns:
        dict mapping operator index -> shape of that operator's first output.
    """
    ws = workspace.C.Workspace()
    for param in net_params.protos:
        param_blob = ws.create_blob(param.name)
        param_blob.feed(utils.Caffe2TensorToNumpyArray(param))
    input_blob = ws.create_blob(net.op[0].input[0])
    input_blob.feed(dummy_input)

    # Get dimensions with legacy pad
    legacy_shapes = {}
    for index, op_def in enumerate(net.op):
        ws._run_operator(op_def.SerializeToString())
        if index not in legacy_pad_ops:
            continue
        first_output = op_def.output[0]
        legacy_shapes[index] = ws.fetch_blob(first_output).shape
    return legacy_shapes
|
||||
|
||||
|
||||
def _GetLegacyPadArgs(op_def, arg_map):
|
||||
pads = {}
|
||||
keys = ['pad_l', 'pad_t', 'pad_r', 'pad_b']
|
||||
is_pad = 'pad' in arg_map
|
||||
if is_pad:
|
||||
for k in keys:
|
||||
pads[k] = arg_map['pad'].i
|
||||
else:
|
||||
pads = {x: arg_map[x].i for x in keys}
|
||||
return pads
|
||||
|
||||
|
||||
def _AdjustDims(op_def, arg_map, pads, dim1, dim2):
|
||||
n1, c1, h1, w1 = dim1
|
||||
n2, c2, h2, w2 = dim2
|
||||
assert(n1 == n2)
|
||||
assert(c1 == c2)
|
||||
is_pad = 'pad' in arg_map
|
||||
if h1 != h2 or w1 != w2:
|
||||
if h1 == h2 + 1:
|
||||
pads['pad_b'] += 1
|
||||
elif h1 != h2:
|
||||
raise Exception("Unexpected dimensions for height:", h1, h2)
|
||||
if w1 == w2 + 1:
|
||||
pads['pad_r'] += 1
|
||||
elif w1 != w2:
|
||||
raise Exception("Unexpected dimensions for width:", w1, w2)
|
||||
if is_pad:
|
||||
op_def.arg.remove(arg_map['pad'])
|
||||
args = []
|
||||
for name in pads.keys():
|
||||
arg = caffe2_pb2.Argument()
|
||||
arg.name = name
|
||||
arg.i = pads[name]
|
||||
args.append(arg)
|
||||
op_def.arg.extend(args)
|
||||
else:
|
||||
for name in pads.keys():
|
||||
arg_map[name].i = pads[name]
|
||||
|
||||
|
||||
def _RemoveLegacyPad(net, net_params, input_dims):
    """Strip `legacy_pad` arguments from `net`, fixing pads to keep shapes.

    Each Conv/ConvTranspose/MaxPool/AveragePool operator carrying a
    `legacy_pad` argument is run once with and once without it on a random
    dummy input; if the output shapes differ, explicit pad arguments are
    adjusted via _AdjustDims.

    Args:
        net: net definition; modified in place.
        net_params: container whose `protos` hold the parameter tensors.
        input_dims: (n, c, h, w) shape of the dummy input.

    Returns:
        The same `net` object.
    """
    legacy_pad_ops = []
    for i, op_def in enumerate(net.op):
        if re.match(r'^(Conv|ConvTranspose|MaxPool|AveragePool)(\dD)?$',
                    op_def.type):
            for arg in op_def.arg:
                if arg.name == 'legacy_pad':
                    legacy_pad_ops.append(i)
                    break
    if legacy_pad_ops:
        n, c, h, w = input_dims
        dummy_input = np.random.randn(n, c, h, w).astype(np.float32)
        dim_map = _GetLegacyDims(net, net_params, dummy_input, legacy_pad_ops)

        # Running with the legacy pad argument removed
        # compare the dimensions and adjust pad argument when necessary
        ws = workspace.C.Workspace()

        external_input = net.op[0].input[0]
        # Blobs are fed with `.feed(...)`, as in _GetLegacyDims and
        # _GetBlobDimMap; `feed_blob` is a Workspace method, not a Blob one.
        ws.create_blob(external_input).feed(dummy_input)
        for param in net_params.protos:
            ws.create_blob(param.name) \
              .feed(utils.Caffe2TensorToNumpyArray(param))

        for i, op_def in enumerate(net.op):
            if i in legacy_pad_ops:
                arg_map = {}
                for arg in op_def.arg:
                    arg_map[arg.name] = arg
                pads = _GetLegacyPadArgs(op_def, arg_map)
                # remove legacy pad arg
                for j in range(len(op_def.arg)):
                    arg = op_def.arg[j]
                    if arg.name == 'legacy_pad':
                        del op_def.arg[j]
                        break
                output = op_def.output[0]
                # use a new name to avoid the interference with inplace
                nonlegacy_output = output + '_nonlegacy'
                op_def.output[0] = nonlegacy_output
                ws._run_operator(op_def.SerializeToString())
                blob_nonlegacy = ws.fetch_blob(nonlegacy_output)
                # reset output name
                op_def.output[0] = output

                dim1 = dim_map[i]
                dim2 = blob_nonlegacy.shape
                _AdjustDims(op_def, arg_map, pads, dim1, dim2)

            # Run the (possibly adjusted) operator under its real output
            # name so later operators see its result.
            ws._run_operator(op_def.SerializeToString())
    return net
|
||||
|
||||
|
||||
def _GetBlobDimMap(net, net_params, dummy_input):
    """Run `net` on `dummy_input` and return the shape of every output blob.

    Args:
        net: net definition whose operators are run one by one.
        net_params: container whose `protos` are fed as parameter blobs.
        dummy_input: array fed to the first input of the first operator.

    Returns:
        dict mapping output blob name -> shape of the fetched blob.
    """
    ws = workspace.C.Workspace()
    for param in net_params.protos:
        param_blob = ws.create_blob(param.name)
        param_blob.feed(utils.Caffe2TensorToNumpyArray(param))
    input_blob = ws.create_blob(net.op[0].input[0])
    input_blob.feed(dummy_input)

    # Run each operator in order and record the shapes of all its outputs.
    blob_shapes = {}
    for op_def in net.op:
        ws._run_operator(op_def.SerializeToString())
        for output_name in op_def.output:
            blob_shapes[output_name] = ws.fetch_blob(output_name).shape
    return blob_shapes
|
||||
|
||||
|
||||
def _GetInputDims(caffe_net):
|
||||
input_dims = []
|
||||
if caffe_net.input_dim:
|
||||
input_dims = caffe_net.input_dim
|
||||
elif caffe_net.input_shape:
|
||||
input_dims = caffe_net.input_shape[0].dim
|
||||
elif caffe_net.layer[0].input_param.shape:
|
||||
# getting input dimension from first layer
|
||||
input_dims = caffe_net.layer[0].input_param.shape[0].dim
|
||||
return input_dims
|
||||
|
||||
|
||||
class TranslatorRegistry:
    """Registry mapping Caffe layer type names to translator functions."""

    # Maps layer type name -> translator function. Shared at class level.
    registry_ = {}

    @classmethod
    def Register(cls, op_name):
        """A decorator for registering a translator for layer type `op_name`."""

        def Wrapper(func):
            cls.registry_[op_name] = func
            return func

        return Wrapper

    @classmethod
    def TranslateLayer(cls, layer, pretrained_blobs, is_test, **kwargs):
        """Translate one layer with the translator registered for its type.

        Returns:
            (caffe_ops, params): `caffe_ops` is always a list; a None result
            becomes [] and a single operator is wrapped in a list.

        Raises:
            KeyError: if no translator is registered for `layer.type`.
        """
        # Only the lookup is guarded, so a KeyError raised inside a
        # translator is not misreported as a missing registration.
        try:
            translator = cls.registry_[layer.type]
        except KeyError as e:
            raise KeyError('No translator registered for layer: %s yet.' %
                           str(layer)) from e
        caffe_ops, params = translator(
            layer, pretrained_blobs, is_test, **kwargs)
        if caffe_ops is None:
            caffe_ops = []
        if not isinstance(caffe_ops, list):
            caffe_ops = [caffe_ops]
        return caffe_ops, params

    @classmethod
    def TranslateModel(
        cls,
        caffe_net,
        pretrained_net,
        is_test=False,
        net_state=None,
        remove_legacy_pad=False,
        input_dims=None
    ):
        """Translate a Caffe net plus pretrained weights into a Caffe2 net.

        Args:
            caffe_net: Caffe net definition using the new-style `layer` field.
            pretrained_net: Caffe net holding the pretrained blobs; layers are
                matched by name in both its `layer` and `layers` fields.
            is_test: forwarded to every layer translator.
            net_state: net state used for include/exclude rules; a default
                caffe_pb2.NetState() is used when None.
            remove_legacy_pad: if True, strip legacy_pad arguments afterwards
                (requires input dimensions).
            input_dims: input dimensions; taken from `caffe_net` when falsy.

        Returns:
            (net, net_params): the translated net and its parameter tensors.

        Raises:
            ValueError: if `caffe_net` uses old-style `layers`, or if several
                pretrained layers share one name.
        """
        net_state = caffe_pb2.NetState() if net_state is None else net_state
        net = caffe2_pb2.NetDef()
        net.name = caffe_net.name
        net_params = caffe2_pb2.TensorProtos()
        if len(caffe_net.layers) > 0:
            raise ValueError(
                'I think something is wrong. This translation script '
                'only accepts new style layers that are stored in the '
                'layer field.'
            )
        if not input_dims:
            input_dims = _GetInputDims(caffe_net)
        for layer in caffe_net.layer:
            if not _ShouldInclude(net_state, layer):
                log.info('Current net state does not need layer {}'
                         .format(layer.name))
                continue
            log.info('Translate layer {}'.format(layer.name))
            # Get pretrained one
            pretrained_layers = (
                [candidate for candidate in pretrained_net.layer
                 if candidate.name == layer.name] +
                [candidate for candidate in pretrained_net.layers
                 if candidate.name == layer.name]
            )
            if len(pretrained_layers) > 1:
                raise ValueError(
                    'huh? more than one pretrained layer of one name?')
            elif len(pretrained_layers) == 1:
                pretrained_blobs = [
                    utils.CaffeBlobToNumpyArray(blob)
                    for blob in pretrained_layers[0].blobs
                ]
            else:
                # No pretrained layer for the given layer name. We'll just pass
                # no parameter blobs.
                pretrained_blobs = []
            operators, params = cls.TranslateLayer(
                layer, pretrained_blobs, is_test, net=net,
                net_params=net_params, input_dims=input_dims)
            net.op.extend(operators)
            net_params.protos.extend(params)
        if remove_legacy_pad:
            assert input_dims, \
                   'Please specify input_dims to remove legacy_pad'
            net = _RemoveLegacyPad(net, net_params, input_dims)
        return net, net_params
|
||||
|
||||
|
||||
def TranslateModel(*args, **kwargs):
    """Module-level shortcut for TranslatorRegistry.TranslateModel.

    All positional and keyword arguments are forwarded unchanged, and the
    classmethod's return value is returned as is.
    """
    return TranslatorRegistry.TranslateModel(*args, **kwargs)
|
||||
|
||||
|
||||
def ConvertTensorProtosToInitNet(net_params, input_name):
    """Wrap the net_params returned by TranslateModel into an init net.

    One GivenTensorFill operator is emitted per tensor, followed by a
    ConstantFill of shape [1] that creates the blob named `input_name`.

    This is a very simple feature that only works with float tensors, and is
    only intended for environments where a single initialization file is
    wanted - for more complex cases, use a db to store the parameters.

    Raises:
        RuntimeError: if a tensor carries no float data.
    """
    init_net = caffe2_pb2.NetDef()
    for tensor in net_params.protos:
        if len(tensor.float_data) == 0:
            raise RuntimeError(
                "Only float tensors are supported in this util.")
        fill_args = [
            utils.MakeArgument("shape", list(tensor.dims)),
            utils.MakeArgument("values", tensor.float_data),
        ]
        fill_op = core.CreateOperator(
            "GivenTensorFill", [], [tensor.name], arg=fill_args)
        init_net.op.extend([fill_op])
    input_placeholder = core.CreateOperator(
        "ConstantFill", [], [input_name], shape=[1])
    init_net.op.extend([input_placeholder])
    return init_net
|
||||
|
||||
|
||||
def BaseTranslate(layer, caffe2_type):
    """Create an operator of `caffe2_type` wired to the layer's blobs.

    The layer's `bottom` names become the operator inputs and its `top` names
    become the operator outputs.

    Returns:
        The new caffe2_pb2.OperatorDef.
    """
    translated = caffe2_pb2.OperatorDef()
    translated.type = caffe2_type
    for target, blob_names in ((translated.input, layer.bottom),
                               (translated.output, layer.top)):
        target.extend(blob_names)
    return translated
|
||||
|
||||
|
||||
def AddArgument(op, key, value):
    """Append to `op.arg` the argument that utils.MakeArgument builds for key/value."""
    argument = utils.MakeArgument(key, value)
    op.arg.extend([argument])
|
||||
|
||||
################################################################################
|
||||
# Common translators for layers.
|
||||
################################################################################
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Input")
def TranslateInput(layer, pretrained_blobs, is_test, **kwargs):
    """Translate an "Input" layer: it yields no operators and no parameters."""
    operators, params = [], []
    return operators, params
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("VideoData")
def TranslateVideoData(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "VideoData" layer: it yields no operators and no parameters."""
    operators, params = [], []
    return operators, params
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Data")
def TranslateData(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Data" layer: it yields no operators and no parameters."""
    operators, params = [], []
    return operators, params
|
||||
|
||||
|
||||
def _TranslateStridePadKernelHelper(param, caffe_op):
    """Copy stride, pad and kernel settings from `param` onto `caffe_op`.

    Shared by the convolution, pooling and deconvolution translators.

    `param.stride`, `param.pad` and `param.kernel_size` may be repeated fields
    (at most one element each is accepted) or plain scalars, which is the
    PoolingParameter case. Explicit `*_h` / `*_w` fields, when set, take
    precedence over the generic value.

    Raises:
        NotImplementedError: if any repeated field has more than one entry.
    """
    try:
        has_multiple_values = (
            len(param.stride) > 1
            or len(param.kernel_size) > 1
            or len(param.pad) > 1
        )
        if has_multiple_values:
            raise NotImplementedError(
                "Translator currently does not support non-conventional "
                "pad/kernel/stride settings."
            )
        stride = param.stride[0] if len(param.stride) else 1
        pad = param.pad[0] if len(param.pad) else 0
        kernel = param.kernel_size[0] if len(param.kernel_size) else 0
    except TypeError:
        # len() failed: this is a PoolingParameter, whose pad, stride and
        # kernel are non-repeated scalar fields.
        stride, pad, kernel = param.stride, param.pad, param.kernel_size

    # Stride
    if param.HasField("stride_h") or param.HasField("stride_w"):
        for key in ("stride_h", "stride_w"):
            AddArgument(caffe_op, key, getattr(param, key))
    else:
        AddArgument(caffe_op, "stride", stride)

    # Pad
    if param.HasField("pad_h") or param.HasField("pad_w"):
        if param.pad_h == param.pad_w:
            AddArgument(caffe_op, "pad", param.pad_h)
        else:
            # Asymmetric h/w padding is spelled out per side.
            per_side = (
                ("pad_t", param.pad_h),
                ("pad_b", param.pad_h),
                ("pad_l", param.pad_w),
                ("pad_r", param.pad_w),
            )
            for key, value in per_side:
                AddArgument(caffe_op, key, value)
    else:
        AddArgument(caffe_op, "pad", pad)

    # Kernel
    if param.HasField("kernel_h") or param.HasField("kernel_w"):
        for key in ("kernel_h", "kernel_w"):
            AddArgument(caffe_op, key, getattr(param, key))
    else:
        AddArgument(caffe_op, "kernel", kernel)
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Convolution3D")
def TranslateConvNd(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Convolution3D" layer into a "Conv" operator.

    Kernels, strides and pads are emitted as 3-element lists ordered
    (temporal/depth, spatial, spatial); pads repeat that triple twice.

    Returns:
        (operator, params) where params holds the weight tensor and, when
        exactly two pretrained blobs are given, the flattened bias tensor.
    """
    param = layer.convolution3d_param
    caffe_op = BaseTranslate(layer, "Conv")
    output = caffe_op.output[0]
    weight_name = output + '_w'
    caffe_op.input.append(weight_name)

    kernels = [param.kernel_depth, param.kernel_size, param.kernel_size]
    AddArgument(caffe_op, "kernels", kernels)
    strides = [param.temporal_stride, param.stride, param.stride]
    AddArgument(caffe_op, "strides", strides)

    # Padding fields may be absent from the param message; default to 0.
    temporal_pad = getattr(param, 'temporal_pad', 0)
    spatial_pad = getattr(param, 'pad', 0)
    AddArgument(caffe_op, "pads", [temporal_pad, spatial_pad, spatial_pad] * 2)

    # weight
    params = [utils.NumpyArrayToCaffe2Tensor(pretrained_blobs[0], weight_name)]
    # bias
    if len(pretrained_blobs) == 2:
        bias_name = output + '_b'
        caffe_op.input.append(bias_name)
        params.append(
            utils.NumpyArrayToCaffe2Tensor(
                pretrained_blobs[1].flatten(), bias_name))
    return caffe_op, params
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Convolution")
def TranslateConv(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Convolution" layer into a "Conv" operator.

    Returns:
        (operator, params) where params holds the weight tensor and, when
        exactly two pretrained blobs are given, the flattened bias tensor.
    """
    param = layer.convolution_param
    caffe_op = BaseTranslate(layer, "Conv")
    output = caffe_op.output[0]
    weight_name = output + '_w'
    caffe_op.input.append(weight_name)
    _TranslateStridePadKernelHelper(param, caffe_op)

    # weight
    params = [utils.NumpyArrayToCaffe2Tensor(pretrained_blobs[0], weight_name)]
    # bias
    if len(pretrained_blobs) == 2:
        bias_name = output + '_b'
        caffe_op.input.append(bias_name)
        params.append(
            utils.NumpyArrayToCaffe2Tensor(
                pretrained_blobs[1].flatten(), bias_name))

    # Group convolution option; the argument is only emitted when non-default.
    if param.group != 1:
        AddArgument(caffe_op, "group", param.group)

    # Dilation - not tested. If you have a model that exercises this and it
    # checks out, please provide a test.
    # One entry is emitted as a single "dilation"; two entries as h/w.
    # Any other count is ignored.
    dilation_count = len(param.dilation)
    if dilation_count == 1:
        AddArgument(caffe_op, "dilation", param.dilation[0])
    elif dilation_count == 2:
        AddArgument(caffe_op, "dilation_h", param.dilation[0])
        AddArgument(caffe_op, "dilation_w", param.dilation[1])
    return caffe_op, params
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Deconvolution")
def TranslateDeconv(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Deconvolution" layer into a "ConvTranspose" operator.

    Returns:
        (operator, [weight]) or, when `bias_term` is set on the layer's
        convolution param, (operator, [weight, bias]).

    Raises:
        NotImplementedError: if the layer uses more than one group.
    """
    param = layer.convolution_param
    if param.group > 1:
        raise NotImplementedError(
            "Translator currently does not support group deconvolution."
        )
    caffe_op = BaseTranslate(layer, "ConvTranspose")
    output = caffe_op.output[0]
    _TranslateStridePadKernelHelper(param, caffe_op)

    weight_name = output + '_w'
    caffe_op.input.extend([weight_name])
    AddArgument(caffe_op, "order", "NCHW")
    params = [utils.NumpyArrayToCaffe2Tensor(pretrained_blobs[0], weight_name)]

    if param.bias_term:
        bias_name = output + '_b'
        params.append(
            utils.NumpyArrayToCaffe2Tensor(
                pretrained_blobs[1].flatten(), bias_name))
        caffe_op.input.extend([bias_name])
    return caffe_op, params
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Crop")
def TranslateCrop(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Crop" layer into a "Slice" operator.

    The slice bounds are derived from the dimensions of the layer's second
    input (the reference blob), which are obtained by running shape inference
    on a random dummy input of shape `input_dims`.

    Required kwargs:
        net, net_params: the net translated so far and its parameters.
        input_dims: 4 values unpacked as (n, c, h, w).

    Returns:
        (operator, []) - the Slice operator takes only the first input of the
        Crop layer and carries "starts" / "ends" arguments.

    Raises:
        AssertionError: if the layer does not specify exactly one offset.
    """
    net, net_params, input_dims = (
        kwargs['net'], kwargs['net_params'], kwargs['input_dims'])
    param = layer.crop_param
    axis, offsets = param.axis, param.offset
    # Validate before the (expensive) dummy run below. The message is built
    # by implicit concatenation so that no source indentation leaks into it.
    assert len(offsets) == 1, (
        'Caffe Translator for Crop only works for offset '
        'of 1 for now')

    n, c, h, w = input_dims
    dummy_input = np.random.randn(n, c, h, w).astype(np.float32)
    dim_map = _GetBlobDimMap(net, net_params, dummy_input)

    caffe_op = BaseTranslate(layer, "Slice")
    reference_dim = dim_map[caffe_op.input[1]]
    dims = len(reference_dim)

    # Axes before `axis` are kept whole (start 0, end -1).
    starts = [0] * axis
    ends = [-1] * axis
    # From `axis` on, take a window of the reference blob's size at `offset`.
    end_offset = [int(offsets[0] + reference_dim[i]) for i in range(axis, dims)]
    ends.extend(end_offset)
    starts.extend([offsets[0]] * len(end_offset))

    # Rebuild the operator with only the first input; the reference blob was
    # needed solely for its shape.
    op = caffe2_pb2.OperatorDef()
    op.input.extend([caffe_op.input[0]])
    op.output.extend(caffe_op.output)
    op.arg.extend(caffe_op.arg)
    op.type = caffe_op.type
    AddArgument(op, "starts", starts)
    AddArgument(op, "ends", ends)
    return op, []
|
||||
|
||||
@TranslatorRegistry.Register("ReLU")
def TranslateRelu(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "ReLU" layer into a "Relu" operator with no parameters."""
    caffe_op = BaseTranslate(layer, "Relu")
    return caffe_op, []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Pooling")
def TranslatePool(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Pooling" layer into a "MaxPool" or "AveragePool" operator.

    Returns:
        (operator, []) - pooling has no parameters.

    Raises:
        NotImplementedError: if the pooling method is neither MAX nor AVE.
    """
    param = layer.pooling_param
    if param.pool == caffe_pb2.PoolingParameter.MAX:
        caffe_op = BaseTranslate(layer, "MaxPool")
    elif param.pool == caffe_pb2.PoolingParameter.AVE:
        caffe_op = BaseTranslate(layer, "AveragePool")
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise NotImplementedError(
            "Translator currently does not support pooling method "
            "{}.".format(param.pool))
    _TranslateStridePadKernelHelper(param, caffe_op)
    AddArgument(caffe_op, "order", "NCHW")
    try:
        # In the Facebook port of Caffe, a torch_pooling field was added to
        # map the pooling computation of Torch. Essentially, it uses
        #   floor((height + 2 * padding - kernel) / stride) + 1
        # instead of
        #   ceil((height + 2 * padding - kernel) / stride) + 1
        # which is Caffe's version.
        # Torch pooling is actually the same as Caffe2 pooling, so we don't
        # need to do anything.
        is_torch_pooling = param.torch_pooling
    except AttributeError:
        # Stock Caffe protobufs do not have the field.
        is_torch_pooling = False
    if not is_torch_pooling:
        AddArgument(caffe_op, "legacy_pad",
                    caffe2_legacy_pb2.CAFFE_LEGACY_POOLING)
    if param.global_pooling:
        AddArgument(caffe_op, "global_pooling", 1)
    return caffe_op, []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Pooling3D")
def TranslatePool3D(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Pooling3D" layer into a "MaxPool" or "AveragePool" operator.

    Kernels, strides and pads are emitted as 3-element lists ordered
    (temporal/depth, spatial, spatial); pads repeat that triple twice.

    Returns:
        (operator, []) - pooling has no parameters.

    Raises:
        NotImplementedError: if the pooling method is neither MAX nor AVE.
    """
    param = layer.pooling3d_param
    if param.pool == caffe_pb2.Pooling3DParameter.MAX:
        caffe_op = BaseTranslate(layer, "MaxPool")
    elif param.pool == caffe_pb2.Pooling3DParameter.AVE:
        caffe_op = BaseTranslate(layer, "AveragePool")
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise NotImplementedError(
            "Translator currently does not support 3D pooling method "
            "{}.".format(param.pool))
    AddArgument(caffe_op, "order", "NCHW")
    AddArgument(
        caffe_op,
        "kernels",
        [param.kernel_depth, param.kernel_size, param.kernel_size])
    AddArgument(
        caffe_op,
        "strides",
        [param.temporal_stride, param.stride, param.stride])
    # Padding fields may be absent from the param message; default to 0.
    temporal_pad = getattr(param, 'temporal_pad', 0)
    spatial_pad = getattr(param, 'pad', 0)
    AddArgument(caffe_op, "pads", [temporal_pad, spatial_pad, spatial_pad] * 2)
    return caffe_op, []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("LRN")
def TranslateLRN(layer, pretrained_blobs, is_test, **kwargs):
    """Translate an "LRN" layer into an "LRN" operator.

    A second output named "_<output>_scale" is added to the operator.

    Raises:
        ValueError: if the norm region is not ACROSS_CHANNELS.
    """
    caffe_op = BaseTranslate(layer, "LRN")
    scale_output = '_' + caffe_op.output[0] + '_scale'
    caffe_op.output.extend([scale_output])
    param = layer.lrn_param
    if param.norm_region != caffe_pb2.LRNParameter.ACROSS_CHANNELS:
        raise ValueError(
            "Does not support norm region other than across channels.")
    lrn_arguments = (
        ("size", int(param.local_size)),
        ("alpha", float(param.alpha)),
        ("beta", float(param.beta)),
        ("bias", float(param.k)),
        ("order", "NCHW"),
    )
    for key, value in lrn_arguments:
        AddArgument(caffe_op, key, value)
    return caffe_op, []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("InnerProduct")
def TranslateInnerProduct(layer, pretrained_blobs, is_test, **kwargs):
    """Translate an "InnerProduct" layer into an "FC" operator.

    The weight blob must be 2-D, or old-style 4-D with leading dims (1, 1);
    it is always reshaped to 2-D. The bias blob is flattened.

    Returns:
        (operator, [weight, bias]).

    Raises:
        ValueError: for a non-default axis or transpose, an unexpected weight
            rank, or a 4-D weight whose first two dims are not 1.
    """
    param = layer.inner_product_param
    try:
        is_unsupported = param.axis != 1 or param.transpose
    except AttributeError:
        # A historic Caffe protobuf may lack the axis and transpose
        # arguments; in that case there is nothing to reject.
        is_unsupported = False
    if is_unsupported:
        raise ValueError(
            "We don't have testing case for non-default axis and transpose "
            "cases yet so we are disabling it for now. If you have a model "
            "with this, please do send us your model for us to update this "
            "support, and you are more than welcome to send a PR for this.")

    caffe_op = BaseTranslate(layer, "FC")
    output = caffe_op.output[0]
    weight_name, bias_name = output + '_w', output + '_b'
    caffe_op.input.extend([weight_name, bias_name])

    # To handle the old-style 4-dimensional blob (1, 1, dim_output, dim_input)
    # case, the pretrained blob is always explicitly reshaped.
    weight_blob = pretrained_blobs[0]
    if weight_blob.ndim not in [2, 4]:
        raise ValueError("Unexpected weight ndim.")
    if weight_blob.ndim == 4 and list(weight_blob.shape[:2]) != [1, 1]:
        raise ValueError(
            "If pretrained blob has 4 dims (old-style Caffe), the first two "
            "should be of value 1, but I got " + str(weight_blob.shape))
    weight = utils.NumpyArrayToCaffe2Tensor(
        weight_blob.reshape(-1, weight_blob.shape[-1]), weight_name)
    bias = utils.NumpyArrayToCaffe2Tensor(
        pretrained_blobs[1].flatten(), bias_name)
    return caffe_op, [weight, bias]
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Dropout")
def TranslateDropout(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Dropout" layer into a "Dropout" operator.

    A second output named "_<output>_mask" is added, and "is_test" is set to
    1 when translating for test mode.
    """
    caffe_op = BaseTranslate(layer, "Dropout")
    mask_output = '_' + caffe_op.output[0] + '_mask'
    caffe_op.output.extend([mask_output])
    AddArgument(caffe_op, "ratio", layer.dropout_param.dropout_ratio)
    if is_test:
        AddArgument(caffe_op, "is_test", 1)
    return caffe_op, []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Softmax")
def TranslateSoftmax(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Softmax" layer into a "Softmax" operator with no parameters."""
    return BaseTranslate(layer, "Softmax"), []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("SoftmaxWithLoss")
def TranslateSoftmaxWithLoss(layer, pretrained_blobs, is_test, **kwargs):
    """Translate "SoftmaxWithLoss" into a chain of three operators.

    Softmax runs on the first bottom blob, LabelCrossEntropy combines it with
    the second bottom blob, and AveragedLoss writes to the layer's first top
    blob. Intermediate blobs are named after the first bottom blob with a
    "_translator_autogen_*" suffix.

    Returns:
        ([softmax_op, xent_op, loss_op], []).
    """
    logits, labels = layer.bottom[0], layer.bottom[1]
    softmax_op = core.CreateOperator(
        "Softmax", [logits], logits + "_translator_autogen_softmax")
    xent_op = core.CreateOperator(
        "LabelCrossEntropy",
        [softmax_op.output[0], labels],
        logits + "_translator_autogen_xent")
    loss_op = core.CreateOperator(
        "AveragedLoss", xent_op.output[0], layer.top[0])
    return [softmax_op, xent_op, loss_op], []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Accuracy")
def TranslateAccuracy(layer, pretrained_blobs, is_test, **kwargs):
    """Translate an "Accuracy" layer; "top_k" is only emitted when it is not 1."""
    caffe_op = BaseTranslate(layer, "Accuracy")
    top_k = layer.accuracy_param.top_k
    if top_k != 1:
        AddArgument(caffe_op, "top_k", top_k)
    return caffe_op, []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Concat")
def TranslateConcat(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Concat" layer into a "Concat" operator in NCHW order.

    A second output named "_<output>_dims" is added to the operator.
    """
    caffe_op = BaseTranslate(layer, "Concat")
    dims_output = '_' + caffe_op.output[0] + '_dims'
    caffe_op.output.extend([dims_output])
    AddArgument(caffe_op, "order", "NCHW")
    return caffe_op, []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("TanH")
def TranslateTanH(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "TanH" layer into a "Tanh" operator with no parameters."""
    return BaseTranslate(layer, "Tanh"), []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("InstanceNorm")
def TranslateInstanceNorm(layer, pretrained_blobs, is_test, **kwargs):
    """Translate an "InstanceNorm" layer into an "InstanceNorm" operator.

    Returns:
        (operator, [weight, bias]) built from the first two pretrained blobs,
        both flattened.
    """
    caffe_op = BaseTranslate(layer, "InstanceNorm")
    output = caffe_op.output[0]
    weight_name, bias_name = output + '_w', output + '_b'
    weight = utils.NumpyArrayToCaffe2Tensor(
        pretrained_blobs[0].flatten(), weight_name)
    bias = utils.NumpyArrayToCaffe2Tensor(
        pretrained_blobs[1].flatten(), bias_name)
    caffe_op.input.extend([weight_name, bias_name])
    AddArgument(caffe_op, "order", "NCHW")
    return caffe_op, [weight, bias]
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("BatchNorm")
def TranslateBatchNorm(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "BatchNorm" layer into a "SpatialBN" operator.

    The first two pretrained blobs are divided by the first element of the
    third blob to obtain the mean and variance tensors. Scale and bias come
    from blobs 3 and 4 when more than three blobs are present (fused BN+Scale);
    otherwise an all-ones scale and all-zeros bias are generated.

    The caller's `pretrained_blobs` is not modified.

    Returns:
        (operator, [scale, bias, mean, var]).

    Raises:
        RuntimeError: if the first element of the third blob is zero.
    """
    caffe_op = BaseTranslate(layer, "SpatialBN")
    output = caffe_op.output[0]
    param = layer.batch_norm_param
    AddArgument(caffe_op, "is_test", is_test)
    AddArgument(caffe_op, "epsilon", param.eps)
    AddArgument(caffe_op, "order", "NCHW")

    caffe_op.input.extend(
        [output + "_scale",
         output + "_bias",
         output + "_mean",
         output + "_var"])
    if not is_test:
        # Training mode additionally writes the running and saved statistics.
        caffe_op.output.extend(
            [output + "_mean",
             output + "_var",
             output + "_saved_mean",
             output + "_saved_var"])

    n_channels = pretrained_blobs[0].shape[0]
    if pretrained_blobs[2][0] == 0:
        raise RuntimeError("scalar is zero.")
    inv_factor = 1. / pretrained_blobs[2][0]
    mean = utils.NumpyArrayToCaffe2Tensor(
        inv_factor * pretrained_blobs[0], output + '_mean')
    var = utils.NumpyArrayToCaffe2Tensor(
        inv_factor * pretrained_blobs[1], output + '_var')

    if len(pretrained_blobs) > 3:
        # IntelCaffe and NVCaffe uses fused BN+Scale,
        # three blobs for BN and two blobs for Scale,
        # so that the total number of blobs becomes five (including scale and bias).
        scale = utils.NumpyArrayToCaffe2Tensor(
            pretrained_blobs[3].flatten(), output + '_scale')
        bias = utils.NumpyArrayToCaffe2Tensor(
            pretrained_blobs[4].flatten(), output + '_bias')
    else:
        # Work on a copy so that neither the caller's array nor its list is
        # altered; the resulting values are the same as tiling the blob after
        # setting its first element to 1.
        unit_factor = np.array(pretrained_blobs[2], copy=True)
        unit_factor[0] = 1
        identity_scale = np.tile(unit_factor, (n_channels, ))
        scale = utils.NumpyArrayToCaffe2Tensor(
            identity_scale, output + '_scale')
        bias = utils.NumpyArrayToCaffe2Tensor(
            np.zeros_like(identity_scale), output + '_bias')

    return caffe_op, [scale, bias, mean, var]
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Eltwise")
def TranslateElementWise(layer, pretrained_blobs, is_test, **kwargs):
    """Translate an "Eltwise" layer into a "Sum" operator.

    Raises:
        RuntimeError: if coefficients are given or the operation code is not 1.
    """
    param = layer.eltwise_param
    # TODO(jiayq): if we have a protobuf that uses this, lift this constraint
    # and verify that we can correctly translate.
    # NOTE(review): operation code 1 is presumably the SUM enum value -
    # confirm against caffe.proto.
    has_coefficients = len(param.coeff)
    if has_coefficients or param.operation != 1:
        raise RuntimeError("This eltwise layer is not yet supported.")
    return BaseTranslate(layer, "Sum"), []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Scale")
def TranslateScale(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Scale" layer into a "Mul" operator, plus "Add" for a bias.

    Only the single-input form (scale stored in the pretrained blobs) with
    `num_axes == 1` is supported. With one pretrained blob a lone Mul is
    emitted; with two, Mul writes to "<output>_internal" and an Add operator
    produces the layer's output from that blob and the bias.

    Returns:
        (list of operators, list of parameter tensors), equal in length.

    Raises:
        RuntimeError: for two-input layers, any other input count,
            `num_axes != 1`, or a blob count other than 1 or 2.
    """
    mul_op = BaseTranslate(layer, "Mul")
    scale_param = layer.scale_param
    AddArgument(mul_op, "axis", scale_param.axis)
    AddArgument(mul_op, "broadcast", True)

    input_count = len(mul_op.input)
    if input_count == 2:
        # TODO(jiayq): find a protobuf that uses this and verify.
        raise RuntimeError("This path has not been verified yet.")
    if input_count != 1:
        raise RuntimeError("Unexpected number of inputs.")

    # From here on the scale parameter lives in the pretrained blobs.
    if scale_param.num_axes != 1:
        raise RuntimeError("This path has not been verified yet.")

    output = mul_op.output[0]
    mul_op_param = output + 'scale_w'
    mul_op.input.append(mul_op_param)
    weights = [
        utils.NumpyArrayToCaffe2Tensor(
            pretrained_blobs[0].flatten(), mul_op_param)
    ]

    add_op = None
    blob_count = len(pretrained_blobs)
    if blob_count == 2:
        # Caffe Scale layer supports a bias term such that it computes
        # (scale_param * X + bias), whereas Caffe2 Mul op doesn't.
        # Include a separate Add op for the bias followed by Mul.
        # The copy is taken before Mul's output is redirected, so Add keeps
        # the layer's original output name and the axis/broadcast arguments.
        add_op = copy.deepcopy(mul_op)
        add_op.type = "Add"
        add_op_param = output + 'scale_b'
        internal_blob = output + "_internal"
        del mul_op.output[:]
        mul_op.output.append(internal_blob)
        del add_op.input[:]
        add_op.input.append(internal_blob)
        add_op.input.append(add_op_param)
        weights.append(utils.NumpyArrayToCaffe2Tensor(
            pretrained_blobs[1].flatten(), add_op_param))
    elif blob_count != 1:
        raise RuntimeError("Unexpected number of pretrained blobs in Scale")
    # blob_count == 1: no bias term in the Scale layer, nothing more to add.

    caffe_ops = [mul_op]
    if add_op:
        caffe_ops.append(add_op)
    assert len(caffe_ops) == len(weights)
    return caffe_ops, weights
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Reshape")
def TranslateReshape(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Reshape" layer into a "Reshape" operator.

    A second output named "_<input>_dims" is added, and the target shape is
    taken from `reshape_param.shape.dim`.
    """
    caffe_op = BaseTranslate(layer, "Reshape")
    dims_output = "_" + caffe_op.input[0] + "_dims"
    caffe_op.output.append(dims_output)
    AddArgument(caffe_op, 'shape', layer.reshape_param.shape.dim)
    return caffe_op, []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Flatten")
def TranslateFlatten(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Flatten" layer.

    Axis 0 maps to "FlattenToVec" and axis 1 to "Flatten".

    Raises:
        NotImplementedError: if `end_axis` is not -1 or `axis` is not 0 or 1.
    """
    param = layer.flatten_param
    if param.end_axis != -1:
        raise NotImplementedError("flatten_param.end_axis not supported yet.")

    op_type_by_axis = {0: "FlattenToVec", 1: "Flatten"}
    op_type = op_type_by_axis.get(param.axis)
    if op_type is None:
        # This could be a Reshape op, but dim size is not known here.
        raise NotImplementedError(
            "Not supported yet for flatten_param.axis {}.".format(param.axis))

    return BaseTranslate(layer, op_type), []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Sigmoid")
def TranslateSigmoid(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Sigmoid" layer into a "Sigmoid" operator with no parameters."""
    return BaseTranslate(layer, "Sigmoid"), []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("ROIPooling")
def TranslateROIPooling(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "ROIPooling" layer into a "RoIPool" operator in NCHW order.

    In test mode "is_test" is set; otherwise an extra "<output>_argmaxes"
    output is added. `pooled_h`, `pooled_w` and `spatial_scale` are copied
    only when they are set on the layer's param.
    """
    caffe_op = BaseTranslate(layer, "RoIPool")
    AddArgument(caffe_op, "order", "NCHW")

    if is_test:
        AddArgument(caffe_op, "is_test", is_test)
    else:
        # Only used for gradient computation
        caffe_op.output.append(caffe_op.output[0] + '_argmaxes')

    param = layer.roi_pooling_param
    for field in ('pooled_h', 'pooled_w', 'spatial_scale'):
        if param.HasField(field):
            AddArgument(caffe_op, field, getattr(param, field))

    return caffe_op, []
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("PReLU")
def TranslatePRelu(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "PReLU" layer into a "PRelu" operator.

    Returns:
        (operator, [slope]) where slope is the first pretrained blob, stored
        under the name "<output>_Slope".
    """
    caffe_op = BaseTranslate(layer, "PRelu")
    slope_name = caffe_op.output[0] + '_Slope'
    caffe_op.input.extend([slope_name])
    slope = utils.NumpyArrayToCaffe2Tensor(pretrained_blobs[0], slope_name)

    return caffe_op, [slope]
|
||||
|
||||
|
||||
@TranslatorRegistry.Register("Reduction")
def TranslateReduction(layer, pretrained_blobs, is_test, **kwargs):
    """Translate a "Reduction" layer into "ReduceBackSum" or "ReduceBackMean".

    The number of reduced dims is the negated `axis`, so only non-positive
    axis values are accepted.

    Raises:
        NotImplementedError: for operations other than SUM and MEAN, or for a
            positive axis.
    """
    param = layer.reduction_param
    operation = param.operation
    if operation == caffe_pb2.ReductionParameter.SUM:
        op_type = "ReduceBackSum"
    elif operation == caffe_pb2.ReductionParameter.MEAN:
        op_type = "ReduceBackMean"
    else:
        raise NotImplementedError("Not yet supported")
    caffe_op = BaseTranslate(layer, op_type)

    if param.axis > 0:
        # We can't figure out the number of dims to reduce from positive axis
        # for back reduction since the shape info is not known here.
        raise NotImplementedError("Not yet supported")
    AddArgument(caffe_op, "num_reduce_dim", -param.axis)

    return caffe_op, []
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: translate a Caffe prototxt + caffemodel pair
    # into a Caffe2 predict net and init net written to disk.
    parser = argparse.ArgumentParser(
        description="Utility to convert pretrained caffe models to Caffe2 models.")
    parser.add_argument("prototext", help="Caffe prototext.")
    parser.add_argument("caffemodel", help="Caffe trained model.")
    parser.add_argument("--init_net", help="Caffe2 initialization net.",
                        default="init_net.pb")
    parser.add_argument("--predict_net", help="Caffe2 prediction net.",
                        default="predict_net.pb")
    # Implicit string concatenation keeps source indentation out of the help
    # text (a backslash continuation inside the literal would embed it).
    parser.add_argument("--remove_legacy_pad",
                        help="Remove legacy pad "
                             "(Only works for nets with one input blob)",
                        action="store_true",
                        default=False)
    parser.add_argument("--input_dims", help="Dimension of input blob", nargs='+',
                        type=int, default=[])
    args = parser.parse_args()

    caffenet = caffe_pb2.NetParameter()
    caffenet_pretrained = caffe_pb2.NetParameter()
    input_proto = args.prototext
    input_caffemodel = args.caffemodel
    output_init_net = args.init_net
    output_predict_net = args.predict_net

    # The prototxt is text-format protobuf; the caffemodel is binary.
    with open(input_proto) as f:
        text_format.Merge(f.read(), caffenet)
    with open(input_caffemodel, 'rb') as f:
        caffenet_pretrained.ParseFromString(f.read())
    net, pretrained_params = TranslateModel(
        caffenet, caffenet_pretrained, is_test=True,
        remove_legacy_pad=args.remove_legacy_pad,
        input_dims=args.input_dims
    )

    # Assume there is one input and one output
    external_input = net.op[0].input[0]
    external_output = net.op[-1].output[0]

    net.external_input.extend([external_input])
    net.external_input.extend([param.name for param in pretrained_params.protos])
    net.external_output.extend([external_output])
    init_net = ConvertTensorProtosToInitNet(pretrained_params, external_input)

    with open(output_predict_net, 'wb') as f:
        f.write(net.SerializeToString())
    # Human-readable dump next to the binary net: the path with 'txt' appended.
    with open(output_predict_net + 'txt', 'w') as f:
        f.write(str(net))
    with open(output_init_net, 'wb') as f:
        f.write(init_net.SerializeToString())
|
||||
90
caffe2/python/caffe_translator_test.py
Normal file
90
caffe2/python/caffe_translator_test.py
Normal file
@ -0,0 +1,90 @@
|
||||
# This is a large test that goes through the translation of the bvlc caffenet
|
||||
# model, runs an example through the whole model, and verifies numerically
|
||||
# that all the results look right. By default, it is disabled unless you
|
||||
# explicitly want to run it.
|
||||
|
||||
from google.protobuf import text_format
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
|
||||
# CAFFE_FOUND records whether both the caffe protobufs and the translator
# could be imported; the tests below are skipped when it stays False.
CAFFE_FOUND = False
try:
    from caffe.proto import caffe_pb2
    from caffe2.python import caffe_translator
except Exception as e:
    # A message is only printed for the "caffe module not found" case; any
    # other import failure also leaves CAFFE_FOUND False, but silently.
    caffe_is_missing = "'caffe'" in str(e)
    if caffe_is_missing:
        print(
            "PyTorch/Caffe2 now requires a separate installation of caffe. "
            "Right now, this is not found, so we will skip the caffe "
            "translator test.")
else:
    CAFFE_FOUND = True
|
||||
|
||||
from caffe2.python import utils, workspace, test_util
|
||||
import unittest
|
||||
|
||||
def setUpModule():
    """Translate the reference CaffeNet and run it once in the workspace.

    Does nothing when caffe or the test data directory is unavailable.
    Otherwise the translated net is dumped as text next to the test data, the
    pretrained parameters and the dumped input are fed into the workspace, and
    the net is run so that the test case can fetch the resulting blobs.
    """
    # Do nothing if caffe and test data is not found
    if not (CAFFE_FOUND and os.path.exists('data/testdata/caffe_translator')):
        return
    # We will do all the computation stuff in the global space.
    caffenet = caffe_pb2.NetParameter()
    caffenet_pretrained = caffe_pb2.NetParameter()
    with open('data/testdata/caffe_translator/deploy.prototxt') as f:
        text_format.Merge(f.read(), caffenet)
    # The caffemodel is a serialized protobuf, so it must be read as bytes;
    # text mode would hand ParseFromString a decoded str (or fail to decode).
    with open('data/testdata/caffe_translator/'
              'bvlc_reference_caffenet.caffemodel', 'rb') as f:
        caffenet_pretrained.ParseFromString(f.read())
    # NOTE(review): the remove_legacy_pad=True pass supplies no input_dims -
    # confirm that TranslateModel accepts that combination.
    for remove_legacy_pad in [True, False]:
        net, pretrained_params = caffe_translator.TranslateModel(
            caffenet, caffenet_pretrained, is_test=True,
            remove_legacy_pad=remove_legacy_pad
        )
    with open('data/testdata/caffe_translator/'
              'bvlc_reference_caffenet.translatedmodel',
              'w') as fid:
        fid.write(str(net))
    for param in pretrained_params.protos:
        workspace.FeedBlob(param.name, utils.Caffe2TensorToNumpyArray(param))
    # Let's also feed in the data from the Caffe test code.
    data = np.load('data/testdata/caffe_translator/data_dump.npy').astype(
        np.float32)
    workspace.FeedBlob('data', data)
    # Actually running the test.
    workspace.RunNetOnce(net.SerializeToString())
|
||||
|
||||
|
||||
@unittest.skipIf(not CAFFE_FOUND,
                 'No Caffe installation found.')
@unittest.skipIf(not os.path.exists('data/testdata/caffe_translator'),
                 'No testdata existing for the caffe translator test. Exiting.')
class TestNumericalEquivalence(test_util.TestCase):
    """Compare blobs in the workspace against dumped reference arrays."""

    def testBlobs(self):
        """Check shape and values of every listed blob.

        Both arrays are divided by the maximum of the fetched blob before
        being compared to 5 decimal places.
        """
        blob_names = (
            "conv1", "pool1", "norm1", "conv2", "pool2", "norm2", "conv3",
            "conv4", "conv5", "pool5", "fc6", "fc7", "fc8", "prob",
        )
        for name in blob_names:
            print('Verifying {}'.format(name))
            caffe2_result = workspace.FetchBlob(name)
            reference_path = 'data/testdata/caffe_translator/' + name + '_dump.npy'
            reference = np.load(reference_path)
            self.assertEqual(caffe2_result.shape, reference.shape)
            scale = np.max(caffe2_result)
            np.testing.assert_almost_equal(
                caffe2_result / scale,
                reference / scale,
                decimal=5
            )
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # The test only runs when at least one command-line argument is given.
    ran_without_arguments = len(sys.argv) == 1
    if ran_without_arguments:
        notice = (
            'If you do not explicitly ask to run this test, I will not run it. '
            'Pass in any argument to have the test run for you.'
        )
        print(notice)
        sys.exit(0)
    unittest.main()
|
||||
833
caffe2/python/checkpoint.py
Normal file
833
caffe2/python/checkpoint.py
Normal file
@ -0,0 +1,833 @@
|
||||
## @package checkpoint
|
||||
# Module caffe2.python.checkpoint
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import logging
|
||||
from caffe2.python import core, context
|
||||
from caffe2.python.net_builder import ops
|
||||
from caffe2.python.task import (
|
||||
final_output,
|
||||
Node,
|
||||
Task,
|
||||
TaskGroup,
|
||||
TaskOutput,
|
||||
WorkspaceType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class Job(context.Managed):
|
||||
"""
|
||||
A Job defines three TaskGroups: the `init_group`, the `epoch_group` and the
|
||||
`exit_group` which will be run by a JobRunner.
|
||||
|
||||
The `init_group` will be run only once at startup. Its role is to
|
||||
initialize globally persistent blobs such as model weights, accumulators
|
||||
and data file lists.
|
||||
|
||||
The `epoch_group` will be run in a loop after init_group. The loop will
|
||||
exit when any of the stop signals added with `add_stop_condition` is True
|
||||
at the end of an epoch.
|
||||
|
||||
The download_group will be run only once, after all the executions of
|
||||
epoch_group finish. Its role is to collect the distribute scattered
|
||||
parameters back after training.
|
||||
|
||||
The `exit_group` will be run only once at the very end of the job, the
|
||||
role of this group is to save the results of training in the end of the job.
|
||||
|
||||
Jobs are context-driven, so that Tasks can be added to the active Job
|
||||
without having to explicitly pass the job object around.
|
||||
|
||||
Example of usage:
|
||||
|
||||
def build_reader(partitions):
|
||||
with Job.current().init_group:
|
||||
reader = HiveReader(init_reader, ..., partitions)
|
||||
Task(step=init_reader)
|
||||
with Job.current().epoch_group:
|
||||
limited_reader = ReaderWithLimit(reader, num_iter=10000)
|
||||
data_queue = pipe(limited_reader, num_threads=8)
|
||||
Job.current().add_stop_condition(limited_reader.data_finished())
|
||||
return data_queue
|
||||
|
||||
def build_hogwild_trainer(reader, model):
|
||||
with Job.current().init_group:
|
||||
Task(step=model.param_init_net)
|
||||
with Job.current().epoch_group:
|
||||
pipe(reader, processor=model, num_threads=8)
|
||||
with Job.current().exit_group:
|
||||
Task(step=model.save_model_net)
|
||||
|
||||
with Job() as job:
|
||||
reader = build_reader(partitions)
|
||||
model = build_model(params)
|
||||
build_hogwild_trainer(reader, model)
|
||||
"""
|
||||
def __init__(self,
|
||||
init_group=None, epoch_group=None,
|
||||
download_group=None, exit_group=None,
|
||||
stop_conditions=None, nodes_to_checkpoint=None):
|
||||
self.init_group = init_group or TaskGroup(
|
||||
workspace_type=WorkspaceType.GLOBAL)
|
||||
self.epoch_group = epoch_group or TaskGroup()
|
||||
self.download_group = download_group or TaskGroup()
|
||||
self.exit_group = exit_group or TaskGroup()
|
||||
self.stop_conditions = stop_conditions or []
|
||||
self._nodes_to_checkpoint = nodes_to_checkpoint
|
||||
|
||||
def nodes_to_checkpoint(self):
    """
    Return the nodes whose state should be checkpointed.

    The explicitly configured value is returned when it is truthy;
    otherwise the result of `self.init_group.used_nodes()` is returned.
    """
    explicit_nodes = self._nodes_to_checkpoint
    return explicit_nodes if explicit_nodes else self.init_group.used_nodes()
|
||||
|
||||
def compile(self, session_class):
    """
    Replace every task group of the job by its compiled form.

    Args:
        session_class: An object exposing `compile(group)`; its return
            value replaces the corresponding group attribute.
    """
    # Resolve the node list first: it may be derived from the init group,
    # which is about to be replaced by its compiled counterpart.
    self._nodes_to_checkpoint = self.nodes_to_checkpoint()
    group_attrs = ('init_group', 'epoch_group', 'download_group', 'exit_group')
    for group_attr in group_attrs:
        compiled_group = session_class.compile(getattr(self, group_attr))
        setattr(self, group_attr, compiled_group)
|
||||
|
||||
def __enter__(self):
    """
    Enter the job's own context and then its epoch group.

    Returns:
        The job itself, so it can be bound with `with Job() as job:`.
    """
    # Order matters: the job context is entered before the epoch group,
    # and `__exit__` unwinds them in the opposite order.
    super().__enter__()
    epoch_group = self.epoch_group
    epoch_group.__enter__()
    return self
|
||||
|
||||
def __exit__(self, *args):
    """
    Leave the epoch group and then the job's own context.

    Args:
        *args: Exception details supplied by the `with` statement; they
            are forwarded to the parent class only.
    """
    # Unwind in the reverse of the order used by `__enter__`. The epoch
    # group's `__exit__` is invoked without the exception details.
    epoch_group = self.epoch_group
    epoch_group.__exit__()
    super().__exit__(*args)
|
||||
|
||||
def add_stop_condition(self, output):
    """
    Register an output that is checked as a stop condition.

    Args:
        output: Either a TaskOutput, or a `core.BlobReference`. A blob
            reference is wrapped in a new Task added to the epoch group
            and that task's first output is registered instead.
    """
    condition = output
    if isinstance(condition, core.BlobReference):
        wrapper_task = Task(outputs=[condition], group=self.epoch_group)
        condition = wrapper_task.outputs()[0]
    assert isinstance(condition, TaskOutput)
    self.stop_conditions.append(condition)
|
||||
|
||||
|
||||
def get_ckpt_filename(node_name, epoch):
    """Build the checkpoint filename for a node and an epoch.

    Args:
        node_name: A string. The name of the node.
        epoch: An integer. The checkpoint epoch.

    Returns:
        ckpt_filename: A string of the form `<node_name>.<epoch>`.
    """
    name_parts = (node_name, str(epoch))
    return '.'.join(name_parts)
|
||||
|
||||
|
||||
def db_name(epoch, node_name, db_prefix, path_prefix=None):
    """Compute the full db name under which a checkpoint is stored.

    Args:
        epoch: An integer. The checkpoint epoch.
        node_name: A string. The name of the node.
        db_prefix: A string. The prefix used to construct full db name.
        path_prefix: A string. Optional param used to construct db name or
            path where checkpoint files are stored. When given (and not
            empty) it takes precedence over `db_prefix`.

    Returns:
        db_name: A string. The checkpoint filename prepended with
            `path_prefix` when one is given, otherwise joined onto
            `db_prefix` as a path component.
    """
    ckpt_filename = get_ckpt_filename(node_name, epoch)
    if path_prefix:
        # Plain concatenation: no path separator is inserted here.
        return path_prefix + ckpt_filename
    return os.path.join(db_prefix, ckpt_filename)
|
||||
|
||||
|
||||
class CheckpointManager:
    """
    Controls saving and loading of workspaces on every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `load` and `save` at different moments in between epoch runs.

    Args:
        db_prefix: The prefix used to construct full db name. Since `absolute_path`
            is set to True, this will be used as db_name in SaveOp.
        node_name: Name of the node where this checkpoint_manager is used.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """

    BLOB_NAMES = "blob_names"

    def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
        self._db_prefix = db_prefix
        self._node_name = node_name
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        # make sure these blobs are the first in the checkpoint file.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput(self.BLOB_NAMES)
        # Set by `init`; holds the task output carrying the blob names.
        self._names_output = None
        self._path_prefix = None
        self._path_type = None
        # Updated by `load`, `load_blobs_from_checkpoint` and `save`.
        self._current_db_name = None
        # Updated by `_timed_task` with the output holding the time span.
        self._current_checkpoint_duration = None

    def init(
        self,
        nodes=None,
        retrieve_from_epoch=None,
        path_prefix=None,
        path_type=None
    ):
        """
        Initialize the checkpoint manager. Determines all blobs that need to
        be saved or loads from a checkpoint.

        Build a Task that will be run once after the job's `init_group` is run.
        This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
                Should only contain a single node.
            retrieve_from_epoch: Set to a number to load blobs from this epoch.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored. Overrides the manager's db type when given.

        Returns:
            The Task that produces the list of blob names.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')

        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is None:
                ops.GetAllBlobNames(
                    [],
                    self._blob_names,
                    include_shared=False)
            else:
                full_db_name = db_name(retrieve_from_epoch,
                                       self._node_name, self._db_prefix,
                                       path_prefix)
                db_type = path_type or self._db_type
                logger.info("Initializing checkpoints from = %s",
                            full_db_name)
                ops.Load(
                    [], self._blob_names,
                    db=full_db_name,
                    db_type=db_type,
                    absolute_path=True,
                    keep_device=True,
                )
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        """Return the blob names fetched from the output of the `init` task."""
        assert self._names_output
        return self._names_output.fetch().tolist()

    def _timed_task(self, cp_op_name, add_op):
        """
        Build a Task that will measure the time span of checkpoint operations,
        once operation is done, time can be read from _current_checkpoint_duration.

        Args:
            cp_op_name: A string name of the checkpoint operation.
            add_op: A functor to add the checkpoint operation.

        Returns:
            A task with timer.
        """
        with Task(name=cp_op_name) as task:
            with ops.task_init():
                timer = ops.TimerBegin([], counter_name=self._node_name)
            add_op()
            with ops.task_exit():
                time_span_blob = ops.TimerGetAndEnd(timer)
            self._current_checkpoint_duration = final_output(time_span_blob)
        return task

    def collect_checkpoint_stats(self, stats):
        """
        Add one checkpoint stats into the stats.

        Args:
            stats: A dict of checkpoint stats that will be reported. It is
                updated in place, keyed by the current db name.
        """
        if self._current_db_name and self._current_checkpoint_duration:
            stats[self._current_db_name] = self._current_checkpoint_duration.fetch()[0]
        else:
            logger.info(
                "Failed to collect checkpoint stats: %s",
                self._current_db_name
            )

    def load(self, epoch, path_prefix=None, path_type=None):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from a persistent storage.

        Args:
            epoch: The checkpoint epoch to load from.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Overrides the manager's db type when given.

        Returns:
            A timed Task named 'checkpoint_load'.
        """
        self._current_db_name = db_name(
            epoch, self._node_name, self._db_prefix, path_prefix
        )
        db_type = path_type or self._db_type
        logger.info("Loading checkpoints from = %s", self._current_db_name)

        def add_op():
            ops.Load(
                [],
                self.blob_list(),
                db=self._current_db_name,
                db_type=db_type,
                absolute_path=True,
                keep_device=True,
            )

        return self._timed_task('checkpoint_load', add_op)

    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Builds a Task that loads only the necessary blobs from a checkpoint of
        the given epoch. The necessary blobs are given in the blob_names
        argument.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task which loads the specified blobs from the checkpoint of the
            given epoch.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Load from %s', self._current_db_name)

        def add_op():
            ops.Load(
                [],
                blob_names,
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True,
                allow_incomplete=True)

        return self._timed_task('checkpoint_partial_load', add_op)

    def check_db_exists(self, epoch):
        """
        Build a Task whose output tells whether the checkpoint db of the
        given epoch exists.

        Args:
            epoch: The checkpoint epoch to look for.

        Returns:
            A Task with a single output holding the result of the DBExists op.
        """
        # Computed once and shared by the log message and the op.
        full_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Check existence of %s', full_db_name)
        with Task() as task:
            existence = ops.Const(False)
            ops.DBExists(
                [],
                [existence],
                db_name=full_db_name,
                db_type=self._db_type,
                absolute_path=True)
            task.add_output(existence)
        return task

    def report_checkpoint_stats(self, action_name):
        """
        Report checkpoint operation stats for current node.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        self.collect_checkpoint_stats(all_stats)
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save ops to serialize and persist
        blobs present in the global workspace.

        Args:
            epoch: The checkpoint epoch being saved.

        Returns:
            A timed Task named 'checkpoint_save'.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Saving to %s', self._current_db_name)

        def add_op():
            ops.Save(
                self.blob_list(), [],
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True)

        return self._timed_task('checkpoint_save', add_op)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from
        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
                Not used by this single-node manager; the manager's own
                node name is what is forwarded to the metadata handler.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored.
        """
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=[str(self._node_name)],
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is
                accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is
                accessible. Always True when no metadata handler is set.
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        else:
            return True
|
||||
|
||||
|
||||
class MultiNodeCheckpointManager:
|
||||
"""
|
||||
Coordinates checkpointing and checkpointing across multiple nodes.
|
||||
Each of `init`, `load` and `save` will build TaskGroups which will
|
||||
trigger checkpointing on each of the nodes involved in a distributed job.
|
||||
|
||||
Args:
|
||||
db_prefix: The prefix used to construct full db name. Since `absolute_path`
|
||||
is set to True, this will be used as db_name in SaveOp.
|
||||
db_type: Type of database to use for storing checkpoint.
|
||||
metadata_handler: An optional object capable of reading/writing
|
||||
checkpoint info in storage of choice.
|
||||
"""
|
||||
def __init__(self, db_prefix, db_type, metadata_handler=None):
|
||||
self._node_managers = None
|
||||
self._db_prefix = db_prefix
|
||||
self._db_type = db_type
|
||||
self._metadata_handler = metadata_handler
|
||||
self._path_prefix = None
|
||||
self._path_type = None
|
||||
|
||||
def _task_group(self, func, *args, **kw):
|
||||
assert self._node_managers is not None, 'init must be called first.'
|
||||
with TaskGroup(WorkspaceType.GLOBAL) as task_group:
|
||||
for node, manager in self._node_managers:
|
||||
with Node(node):
|
||||
func(manager, *args, **kw)
|
||||
return task_group
|
||||
|
||||
"""
|
||||
Args:
|
||||
nodes: An array of nodes where this checkpoint manager is running.
|
||||
retrieve_from_epoch: Set to a number to load blobs from this epoch.
|
||||
path_prefix: Used to construct db name or path where checkpoint files are
|
||||
stored.
|
||||
path_type: Indicate the type of path where checkpoint files are stored.
|
||||
"""
|
||||
def init(
|
||||
self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
|
||||
):
|
||||
if self._node_managers is not None:
|
||||
assert [node for node, _ in self._node_managers] == nodes
|
||||
return TaskGroup(WorkspaceType.GLOBAL)
|
||||
self._node_managers = []
|
||||
for node in nodes:
|
||||
with Node(node):
|
||||
manager = CheckpointManager(
|
||||
db_prefix=self._db_prefix,
|
||||
node_name=str(node),
|
||||
db_type=self._db_type)
|
||||
self._node_managers.append((node, manager))
|
||||
return self._task_group(
|
||||
CheckpointManager.init,
|
||||
nodes=[node],
|
||||
retrieve_from_epoch=retrieve_from_epoch,
|
||||
path_prefix=path_prefix,
|
||||
path_type=path_type)
|
||||
|
||||
def load(self, epoch, path_prefix=None, path_type=None):
|
||||
return self._task_group(
|
||||
CheckpointManager.load,
|
||||
epoch,
|
||||
path_prefix=path_prefix,
|
||||
path_type=path_type)
|
||||
|
||||
def load_blobs_locally(self, nodes, blob_names, epoch, session):
|
||||
"""Loads the necessary blobs from the checkpoints to the current node.
|
||||
|
||||
Args:
|
||||
blob_names: A list of strings. Each string is the name of a
|
||||
blob.
|
||||
epoch: An integer. The checkpoint epoch to load from.
|
||||
session: A Session object to execute the Load ops.
|
||||
"""
|
||||
if self._node_managers is not None:
|
||||
assert [node for node, _ in self._node_managers] == nodes
|
||||
else:
|
||||
self._node_managers = []
|
||||
for node in nodes:
|
||||
with Node(node):
|
||||
manager = CheckpointManager(
|
||||
db_prefix=self._db_prefix,
|
||||
node_name=str(node),
|
||||
db_type=self._db_type)
|
||||
self._node_managers.append((node, manager))
|
||||
assert self._node_managers is not None, 'must initialize node managers'
|
||||
for _, manager in self._node_managers:
|
||||
existence_task = manager.check_db_exists(epoch)
|
||||
session.run(existence_task)
|
||||
existence = existence_task.outputs()[0].fetch()
|
||||
if not existence:
|
||||
logger.info('DB %s does not exist!' %
|
||||
db_name(epoch, manager._node_name, manager._db_prefix))
|
||||
return False
|
||||
load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
|
||||
session.run(load_task)
|
||||
logger.info('Successfully loaded from checkpoints.')
|
||||
return True
|
||||
|
||||
def get_ckpt_db_name(self, node_name, epoch):
|
||||
"""Returns the DB name of the given node and the given epoch.
|
||||
|
||||
The DB name is effectively the checkpoint path of the given node and
|
||||
the given epoch.
|
||||
|
||||
Args:
|
||||
node_name: A string. The node name of interest.
|
||||
epoch: An integer. The epoch of the checkpoint.
|
||||
|
||||
Returns:
|
||||
checkpoint_db_name: A string. The checkpoint path of the given
|
||||
node and the given epoch.
|
||||
"""
|
||||
for node, manager in self._node_managers:
|
||||
if str(node) == node_name:
|
||||
return db_name(epoch, manager._node_name, manager._db_prefix)
|
||||
|
||||
def report_checkpoint_stats(self, action_name):
|
||||
"""
|
||||
Report the checkpoint stats for all the nodes, we need to aggregate all
|
||||
the node's stats together so that we know which node's checkpoint
|
||||
operation dominates.
|
||||
|
||||
Args:
|
||||
action_name: A string of the name of checkpoint operation.
|
||||
"""
|
||||
all_stats = {}
|
||||
for _, manager in self._node_managers:
|
||||
manager.collect_checkpoint_stats(all_stats)
|
||||
logger.debug("checkpoint stats: {}".format(all_stats))
|
||||
if self._metadata_handler:
|
||||
self._metadata_handler.report(action_name, all_stats)
|
||||
|
||||
def save(self, epoch):
|
||||
"""
|
||||
Build a Task that will execute a Save ops to serialize and persist
|
||||
blobs present in the global workspace.
|
||||
"""
|
||||
return self._task_group(CheckpointManager.save, epoch)
|
||||
|
||||
def write_checkpoint_metadata(self, epoch):
|
||||
"""
|
||||
Write metadata for checkpoint
|
||||
|
||||
Args:
|
||||
epoch: An integer. The epoch-id for which checkpoint metadata is
|
||||
written
|
||||
"""
|
||||
if self._metadata_handler is not None:
|
||||
self._metadata_handler.write(epoch=epoch)
|
||||
|
||||
def get_resume_from_epoch_id(self, user_epoch=None):
|
||||
"""
|
||||
Identify the epoch-id from which Job must resume
|
||||
|
||||
Args:
|
||||
user_epoch: An integer. Optional parameter for user to explicitly
|
||||
identify the epoch-id to load checkpoint from
|
||||
Returns:
|
||||
epoch: the epoch-id to load checkpoints from
|
||||
or None if no checkpoints were written
|
||||
"""
|
||||
last_epoch = user_epoch
|
||||
if self._metadata_handler is not None:
|
||||
last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
|
||||
return last_epoch
|
||||
|
||||
def set_params(self, nodes, path_prefix=None, path_type=None):
|
||||
"""Set parameters associated with CP manager
|
||||
|
||||
Args:
|
||||
nodes: An array of nodes where this checkpoint manager is running.
|
||||
path_prefix: Used to construct db name or path where checkpoint files are
|
||||
stored.
|
||||
path_type: Indicate the type of path where checkpoint files are stored.
|
||||
"""
|
||||
self._node_names = [str(node) for node in nodes]
|
||||
if path_prefix:
|
||||
self._path_prefix = path_prefix
|
||||
if path_type:
|
||||
self._path_type = path_type
|
||||
if self._metadata_handler:
|
||||
self._metadata_handler.set_params(
|
||||
db_prefix=self._db_prefix,
|
||||
db_type=self._db_type,
|
||||
node_names=self._node_names,
|
||||
path_prefix=self._path_prefix,
|
||||
path_type=self._path_type)
|
||||
|
||||
def cp_accessible(self, epoch=None):
|
||||
"""Returns True if Checkpoint data is accessible
|
||||
|
||||
Args:
|
||||
epoch: An integer. The epoch of the checkpoint. If None,
|
||||
it implies we need to check if checkpoint directory is accessible
|
||||
|
||||
Returns:
|
||||
is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
|
||||
"""
|
||||
if self._metadata_handler is not None:
|
||||
return self._metadata_handler.cp_accessible(epoch)
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
class UploadTaskGroupBuilder:
    """Interface for objects that build a checkpoint-upload task group."""

    def build(self, epoch, checkpoint_manager):
        """Build the task group that uploads the checkpoints.

        Args:
            epoch: An integer. The checkpoint epoch to be uploaded.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.

        Raises:
            NotImplementedError: Always; this base class only declares the
                interface and subclasses provide the implementation.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class JobRunner:
    """
    Implement the runtime logic for jobs with checkpointing at the level of
    epoch. Can be used to run either single-host or distributed jobs. The
    runner's `train` method is to be called once from the master, passing a
    session as an argument. This call will block until the Job execution is
    complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
                 upload_task_group_builder=None):
        """Initializes the JobRunner.

        Args:
            job: A Job object. The job to be executed.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.
            resume_from_epoch: An integer. The epoch to resume from.
            upload_task_group_builder: A subclass of the
                UploadTaskGroupBuilder. Creates a task group to upload
                checkpoints.
        """
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint_manager = checkpoint_manager
        self.job = job
        self.upload_task_group_builder = upload_task_group_builder

    def train(self, session):
        """Runs the training flow.

        Args:
            session: A Session object. Valid choices are: LocalSession,
                LocalHostScheduler, and DistributedSession. It is used to
                execute one TaskGroup a time.

        Returns:
            The number of the last epoch that was run.
        """
        from_scratch = self._initialize(session)
        logger.info("Finished initializing")

        # Start training.
        first_epoch = 1 if from_scratch else self.resume_from_epoch + 1
        epoch = self._run_epochs(session, first_epoch)
        logger.info('Finished training')

        self._finalize(session, epoch)
        return epoch

    def _initialize(self, session):
        """Run the init group or restore a checkpoint; return True when
        training starts from scratch."""
        # identify the epoch we must resume from
        if self.checkpoint_manager:
            self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
            self.resume_from_epoch = (
                self.checkpoint_manager.get_resume_from_epoch_id(
                    self.resume_from_epoch))
            if self.resume_from_epoch is not None:
                logger.info('Resuming from epoch %s', self.resume_from_epoch)

        # Initialize all the nodes.
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            session.run(self.job.init_group)

        if self.checkpoint_manager:
            logger.info('Preparing checkpoints ...')
            session.run(self.checkpoint_manager.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch))
            # Save the first checkpoint before training starts, or resume from
            # a previously saved checkpoint.
            if from_scratch:
                self.save_checkpoints(0, session)
            else:
                logger.info('Loading checkpoints for epoch %s ...',
                            self.resume_from_epoch)
                session.run(
                    self.checkpoint_manager.load(self.resume_from_epoch))
                self.checkpoint_manager.report_checkpoint_stats('checkpoint_load')
                logger.info('Checkpoint loaded')
        return from_scratch

    def _run_epochs(self, session, first_epoch):
        """Run the epoch group until a stop condition fires; return the
        number of the last epoch that was run."""
        epoch = first_epoch
        while True:
            logger.info('Starting epoch %d', epoch)
            session.run(self.job.epoch_group)
            logger.info('Finished epoch %d', epoch)
            # Stop conditions are fetched before the checkpoint is saved.
            stop_conditions = [o.fetch() for o in self.job.stop_conditions]

            if self.checkpoint_manager:
                self.save_checkpoints(epoch, session)

            if any(stop_conditions):
                logger.info('Stopping')
                return epoch
            epoch += 1

    def _finalize(self, session, epoch):
        """Upload the checkpoints (if a builder is set), then run the
        download and exit groups."""
        # Upload the checkpoints.
        if self.upload_task_group_builder:
            upload_task_group = self.upload_task_group_builder.build(
                epoch, self.checkpoint_manager)
            session.run(upload_task_group)
            logger.info('Finished uploading the checkpoints')

        # Download the parameters to save
        session.run(self.job.download_group)
        logger.info('Finished downloading the parameters')

        # Finally run the exit step to save nets
        session.run(self.job.exit_group)
        logger.info('Finished running the exit group')

    def load_blobs_from_checkpoints(self, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints.

        Checkpoints store the snapshots of the workspace in each node.
        Sometimes we only need to load a subset of the blobs from the
        checkpoints. One common scenario is to load only the model blobs from
        the checkpoints for evaluation purpose. Given the names of the
        necessary blobs, this function goes over all the checkpoints of all the
        nodes, but only loads the blobs specified in the blob_names to the
        current workspace.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the load ops.

        Returns:
            The value returned by the checkpoint manager's
            `load_blobs_locally`.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        logger.info('Loading checkpoint for epoch %s ...', epoch)
        result = self.checkpoint_manager.load_blobs_locally(
            self.job.nodes_to_checkpoint(), blob_names, epoch, session)
        self.checkpoint_manager.report_checkpoint_stats('checkpoint_partial_load')
        return result

    def save_checkpoints(self, epoch, session):
        """Triggers operation to save checkpoints

        This method will trigger the Save ops to serialize and persist the
        blobs present in the global workspace.

        Args:
            epoch: An integer. The checkpoint epoch-id that we are saving.
            session: A Session object to execute the save ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        try:
            is_accessible = self.checkpoint_manager.cp_accessible(epoch=None)
            if is_accessible:
                logger.info('Saving checkpoints for epoch %s', epoch)
                session.run(self.checkpoint_manager.save(epoch))
                self.checkpoint_manager.write_checkpoint_metadata(epoch)
                logger.info('Checkpoints saved')
                self.checkpoint_manager.report_checkpoint_stats('checkpoint_save')
            else:
                logger.warning("Checkpoint files cannot be accessed!")
        except Exception as ex:
            # Best effort: a failed checkpoint is logged and training goes on.
            logger.warning("Unable to write checkpoint for epoch %s. Error=%s",
                           epoch, ex)
|
||||
|
||||
|
||||
def epoch_limiter(job, num_epochs):
    """
    Add a stop condition to `job` that outputs True once the given number
    of epochs has finished.

    A counter is created in the job's init group and counted down once per
    run of the epoch group.
    """
    with job.init_group:
        counter_init_net = core.Net('epoch_counter_init')
        epoch_counter = counter_init_net.CreateCounter(
            [], init_count=num_epochs - 1)
        Task(step=counter_init_net)

    with job.epoch_group:
        countdown_net = core.Net('epoch_countdown')
        finished_blob = countdown_net.CountDown(epoch_counter)
        countdown_task = Task(step=countdown_net, outputs=finished_blob)
        stop_output = countdown_task.outputs()[0]
    job.add_stop_condition(stop_output)
|
||||
338
caffe2/python/checkpoint_test.py
Normal file
338
caffe2/python/checkpoint_test.py
Normal file
@ -0,0 +1,338 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python.schema import Struct, ConstRecord
|
||||
from caffe2.python import core, workspace, model_helper
|
||||
from caffe2.python.session import LocalSession
|
||||
from caffe2.python.dataset import Dataset
|
||||
from caffe2.python.pipeline import pipe
|
||||
from caffe2.python.checkpoint import (
|
||||
CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner, epoch_limiter,
|
||||
UploadTaskGroupBuilder, db_name)
|
||||
from caffe2.python.net_builder import ops
|
||||
from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType, Cluster
|
||||
from caffe2.python.test_util import TestCase
|
||||
from caffe2.python.dataio import ReaderWithLimit
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
|
||||
def build_pipeline(node_id):
    """
    Add a small accumulate-into-`total` pipeline for node `trainer_<node_id>`
    to the current Job.

    The init group creates a ten-element dataset and a `total` blob starting
    at 100; each epoch reads at most three records and adds their values to
    `total`. The reader's `data_finished()` is registered as the job's stop
    condition.

    Returns:
        A one-element list holding the `total` blob.
    """
    with Node('trainer_%d' % node_id):
        with Job.current().init_group, Task():
            values = Struct(('val', np.array(list(range(10)))))
            record = ConstRecord(ops, values)
            dataset = Dataset(record, name='dataset:%d' % node_id)
            dataset_reader = dataset.reader(ops)
            total = ops.Const([100])

        def inc_total(rec):
            # Accumulate in place: `total` is both an input and the output.
            ops.Add([total, rec.val()], [total])

        limited_reader = ReaderWithLimit(dataset_reader, num_iter=3)
        pipe(limited_reader, processor=inc_total)
        Job.current().add_stop_condition(limited_reader.data_finished())
    return [total]
|
||||
|
||||
|
||||
# Expected value of the `total` blob after each successive epoch of the
# pipeline built by `build_pipeline`: it starts at 100 and each epoch adds
# the next (up to) three values of 0..9, i.e. 100+0+1+2, +3+4+5, +6+7+8, +9.
# The number of entries is also the expected number of epochs.
EXPECTED_TOTALS = [103, 115, 136, 145]
|
||||
|
||||
|
||||
def local_copy_op(src, dest):
    """
    Return a two-argument callable that copies the file `src` to `dest`.

    Nothing is copied until the returned callable is invoked.
    """
    def _copy(inputs, outputs):
        # Both arguments are ignored; the paths are captured from the
        # enclosing call.
        shutil.copyfile(src, dest)
        return None

    return _copy
|
||||
|
||||
|
||||
class UploadToLocalFile(UploadTaskGroupBuilder):
    """Upload builder that copies each node's checkpoint into a local
    directory."""

    def __init__(self, dest_dir):
        # Directory that receives one file per node, named after the node.
        self.dest_dir = dest_dir

    def build(self, epoch, checkpoint_manager):
        """
        Build a task group with one copy task per node manager.

        Args:
            epoch: The checkpoint epoch whose db files are copied.
            checkpoint_manager: A manager exposing `_node_managers` as
                (node, manager) pairs.

        Returns:
            The TaskGroup containing the copy tasks.
        """
        with TaskGroup(WorkspaceType.GLOBAL) as upload_task_group:
            for node, manager in checkpoint_manager._node_managers:
                node_label = str(node)
                with Node(node_label), Task():
                    checkpoint_path = db_name(
                        epoch, manager._node_name, manager._db_prefix)
                    target_path = os.path.join(self.dest_dir, node_label)
                    copy_spec = (
                        local_copy_op, [checkpoint_path, target_path], {})
                    ops.Python(copy_spec)([], [])
        return upload_task_group
|
||||
|
||||
|
||||
class TestCheckpoint(TestCase):
    """End-to-end tests for checkpoint managers driven through JobRunner."""

    def run_with(self, builder):
        """Train, resume and reload a single-node pipeline.

        Args:
            builder: zero-argument callable returning a fresh
                ``(session, checkpoint_manager)`` pair.
        """
        with Cluster():
            with Job() as job:
                outputs = build_pipeline(node_id=0)
            output_fetcher = Task(step=core.Net('empty'), outputs=outputs)

            def fetch_total(session):
                session.run(output_fetcher)
                return output_fetcher.outputs()[0].fetch()

            session, checkpoint = builder()
            job.compile(LocalSession)
            num_epochs = JobRunner(job, checkpoint).train(session)
            self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
            self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

            # Resuming from any saved epoch must reach the same final total.
            for initial_epoch in range(1, num_epochs + 1):
                session, checkpoint = builder()
                JobRunner(
                    job,
                    checkpoint, resume_from_epoch=initial_epoch
                ).train(session)
                self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

            # Loading an epoch's checkpoint must restore that epoch's total.
            for epoch in range(1, num_epochs + 1):
                session.run(checkpoint.load(epoch))
                self.assertEqual(fetch_total(session),
                                 EXPECTED_TOTALS[epoch - 1])

    def test_single_checkpoint(self):
        """Run the shared scenario with single- and multi-node managers."""
        # test single node
        # The directory is created before ``try`` so that a failing
        # mkdtemp() is not masked by a NameError raised in ``finally``.
        tmpdir = tempfile.mkdtemp()
        try:
            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = CheckpointManager(tmpdir, 'temp_node', 'minidb')
                return session, checkpoint

            self.run_with(builder)
        finally:
            shutil.rmtree(tmpdir)

        # test multi-node
        tmpdir = tempfile.mkdtemp()
        try:
            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                return session, checkpoint

            self.run_with(builder)
        finally:
            shutil.rmtree(tmpdir)

    def test_ckpt_name_and_load_model_from_ckpts(self):
        """Check checkpoint db naming and loading model blobs by epoch."""
        num_nodes = 3

        # First, check if the checkpoint name generation mechanism is
        # correct.
        tmpdir = tempfile.mkdtemp()
        try:
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Cluster():
                with Job() as job:
                    for node_id in range(num_nodes):
                        build_pipeline(node_id)
                job.compile(LocalSession)
                checkpoint.init(job.nodes_to_checkpoint())

                for node_id in range(num_nodes):
                    epoch = 5
                    node_name = 'trainer_%d' % node_id
                    expected_db_name = tmpdir + '/' + node_name + '.5'
                    self.assertEqual(
                        checkpoint.get_ckpt_db_name(node_name, epoch),
                        expected_db_name)
        finally:
            # Each temporary directory has its own try/finally so that it is
            # removed exactly once, whichever phase fails.
            shutil.rmtree(tmpdir)

        # Next, check mechanism to load model from checkpoints.
        tmpdir = tempfile.mkdtemp()
        try:
            workspace.ResetWorkspace()
            for node_id in range(num_nodes):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                with Cluster():
                    with Job() as job:
                        build_pipeline(node_id)
                    job.compile(LocalSession)
                    job_runner = JobRunner(job, checkpoint)
                    num_epochs = job_runner.train(session)
                self.assertEqual(num_epochs, len(EXPECTED_TOTALS))

                # There are 17 global blobs after finishing up the job runner.
                # (only blobs on init_group are checkpointed)
                self.assertEqual(len(ws.blobs), 17)

            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            self.assertEqual(len(ws.blobs), 0)
            model_blob_names = ['trainer_1/task_2/GivenTensorInt64Fill:0',
                                'trainer_2/task_2/GivenTensorInt64Fill:0']
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Cluster():
                with Job() as job:
                    for node_id in range(num_nodes):
                        build_pipeline(node_id)
                job.compile(LocalSession)
                job_runner = JobRunner(job, checkpoint)
                job_runner.load_blobs_from_checkpoints(
                    blob_names=model_blob_names, epoch=1, session=session)

                # Check that we can successfully load from checkpoints of epochs
                # 1 to 4, but not epoch 5.
                for epoch in range(1, 5):
                    self.assertTrue(
                        job_runner.load_blobs_from_checkpoints(
                            blob_names=model_blob_names, epoch=epoch,
                            session=session))
                    # Check that all the model blobs are loaded.
                    for blob_name in model_blob_names:
                        self.assertTrue(ws.has_blob(blob_name))
                        self.assertEqual(
                            ws.fetch_blob(blob_name),
                            np.array([EXPECTED_TOTALS[epoch - 1]]))
                self.assertFalse(
                    job_runner.load_blobs_from_checkpoints(
                        blob_names=model_blob_names, epoch=5, session=session))
        finally:
            shutil.rmtree(tmpdir)

    def test_upload_checkpoint(self):
        """Check that the upload task group writes one file per node."""
        tmpdir = tempfile.mkdtemp()
        try:
            upload_dir = os.path.join(tmpdir, "upload")
            os.mkdir(upload_dir)
            num_nodes = 3

            # The uploaded files do not exist yet.
            for node_id in range(num_nodes):
                node_name = 'trainer_%d' % node_id
                upload_path = os.path.join(upload_dir, node_name)
                self.assertFalse(os.path.exists(upload_path))

            # Create and run the job runner.
            for node_id in range(num_nodes):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                with Cluster():
                    with Job() as job:
                        build_pipeline(node_id)
                    job.compile(LocalSession)
                    local_upload_builder = UploadToLocalFile(upload_dir)
                    job_runner = JobRunner(
                        job, checkpoint,
                        upload_task_group_builder=local_upload_builder)
                    num_epochs = job_runner.train(session)
                    self.assertEqual(num_epochs, len(EXPECTED_TOTALS))

            # The uploaded files should exist now.
            for node_id in range(num_nodes):
                node_name = 'trainer_%d' % node_id
                upload_path = os.path.join(upload_dir, node_name)
                self.assertTrue(os.path.exists(upload_path))
        finally:
            shutil.rmtree(tmpdir)

    def test_ckpt_save_failure(self):
        """Check that training completes even if saving checkpoints fails."""
        num_nodes = 3
        # The goal of this test is to ensure that the job runs
        # successfully even if saving a checkpoint fails.
        # Hence tmpdir is a non existent directory to emulate a failure
        # while saving checkpoints
        tmpdir = "/tmp/path_does_not_exist/"

        # Check the saving checkpoint failure does not cause job failure
        workspace.ResetWorkspace()
        for node_id in range(num_nodes):
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Cluster():
                with Job() as job:
                    build_pipeline(node_id)
                job.compile(LocalSession)
                job_runner = JobRunner(job, checkpoint)
                num_epochs = job_runner.train(session)
                # make sure all epochs are executed even though saving the checkpoint failed
                # Saving checkpoint failure should not cause job failure
                self.assertEqual(num_epochs, len(EXPECTED_TOTALS))

    def test_download_group_simple(self):
        """
        A simple test that ensures we have download task group
        executed between epoch_group and exit_group.
        """
        model = model_helper.ModelHelper(name="test_model")
        download_net = core.Net("download_net")

        for name in ["input1", "input2", "output", "download_result"]:
            model.param_init_net.ConstantFill([],
                                              [name],
                                              shape=[8, ],
                                              value=1.0,
                                              run_once=0)
        model.net.Add(["input1", "input2"], ["output"])
        download_net.Copy(["output"], ["download_result"])

        # All blob values are initialized as 1.0, after download_net executed
        # we expect to see download result is the same as training result.
        with Job() as job:
            with Node("trainer:0"):
                with job.init_group:
                    Task(step=model.param_init_net)
                with job.epoch_group:
                    with Task():
                        with ops.loop(1):
                            ops.net(model.net)
                with job.download_group:
                    Task(step=download_net)

                epoch_limiter(job, 1)

        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        job_runner = JobRunner(job)
        job_runner.train(session)

        expected_result = np.full(8, 2.0).astype(np.float32)
        self.assertTrue(np.array_equal(expected_result,
                                       ws.fetch_blob("output")))
        self.assertTrue(np.array_equal(expected_result,
                                       ws.fetch_blob("download_result")))

    def test_reuse_checkpoint_manager(self):
        """
        A simple test that ensures we can reuse a MultiNodeCheckpointManager
        object.
        """
        tmpdir = tempfile.mkdtemp()
        try:
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')

            with Job() as job:
                outputs = build_pipeline(node_id=0)
            output_fetcher = Task(step=core.Net('empty'), outputs=outputs)
            job.compile(LocalSession)

            def fetch_total(session):
                session.run(output_fetcher)
                return output_fetcher.outputs()[0].fetch()

            num_epochs = JobRunner(job, checkpoint).train(session)
            for initial_epoch in range(1, num_epochs + 1):
                JobRunner(
                    job,
                    checkpoint,
                    resume_from_epoch=initial_epoch
                ).train(session)
                self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])
        finally:
            shutil.rmtree(tmpdir)
|
||||
15
caffe2/python/clean_workspace_test.py
Normal file
15
caffe2/python/clean_workspace_test.py
Normal file
@ -0,0 +1,15 @@
|
||||
import unittest
|
||||
|
||||
from caffe2.python import workspace
|
||||
|
||||
|
||||
# This test is extracted out from workspace_test.py because it relies on the pristine
|
||||
# state of the initial workspace. When tests are run in different orders, this test may
|
||||
# become flaky because of global state modifications impacting what the root folder is
|
||||
# after a reset.
|
||||
class TestWorkspace(unittest.TestCase):
    """Checks the workspace root folder before and after a reset."""

    def testRootFolder(self):
        """Reset with no argument, then with an explicit root folder."""
        default_reset = workspace.ResetWorkspace()
        self.assertEqual(default_reset, True)
        self.assertEqual(workspace.RootFolder(), ".")

        custom_root = "/tmp/caffe-workspace-test"
        custom_reset = workspace.ResetWorkspace(custom_root)
        self.assertEqual(custom_reset, True)
        self.assertEqual(workspace.RootFolder(), custom_root)
|
||||
240
caffe2/python/cnn.py
Normal file
240
caffe2/python/cnn.py
Normal file
@ -0,0 +1,240 @@
|
||||
## @package cnn
|
||||
# Module caffe2.python.cnn
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import brew, workspace
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
from caffe2.proto import caffe2_pb2
|
||||
import logging
|
||||
|
||||
|
||||
class CNNModelHelper(ModelHelper):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.

    Deprecated: constructing an instance logs a deprecation warning. Each
    operator method is a thin wrapper that forwards to the matching ``brew``
    helper, injecting this model's ``order`` / cuDNN settings where shown.
    """

    def __init__(self, order="NCHW", name=None,
                 use_cudnn=True, cudnn_exhaustive_search=False,
                 ws_nbytes_limit=None, init_params=True,
                 skip_sparse_optim=False,
                 param_model=None):
        """Create the helper and record the CNN-wide settings.

        Args:
            order: storage order, ``"NCHW"`` or ``"NHWC"``.
            name: model name; ``"CNN"`` is used when ``None``.
            use_cudnn: forwarded to the helpers that accept ``use_cudnn``.
            cudnn_exhaustive_search: forwarded to the convolution helpers.
            ws_nbytes_limit: forwarded to the convolution helpers; only put
                into the arg scope when truthy.
            init_params, skip_sparse_optim, param_model: passed through to
                ``ModelHelper.__init__``.

        Raises:
            ValueError: if ``order`` is neither ``"NHWC"`` nor ``"NCHW"``.
        """
        logging.warning(
            "[====DEPRECATE WARNING====]: you are creating an "
            "object from CNNModelHelper class which will be deprecated soon. "
            "Please use ModelHelper object with brew module. For more "
            "information, please refer to caffe2.ai and python/brew.py, "
            "python/brew_test.py for more information."
        )

        cnn_arg_scope = {
            'order': order,
            'use_cudnn': use_cudnn,
            'cudnn_exhaustive_search': cudnn_exhaustive_search,
        }
        if ws_nbytes_limit:
            cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit
        super().__init__(
            skip_sparse_optim=skip_sparse_optim,
            name="CNN" if name is None else name,
            init_params=init_params,
            param_model=param_model,
            arg_scope=cnn_arg_scope,
        )

        self.order = order
        self.use_cudnn = use_cudnn
        self.cudnn_exhaustive_search = cudnn_exhaustive_search
        self.ws_nbytes_limit = ws_nbytes_limit
        if self.order not in ("NHWC", "NCHW"):
            raise ValueError(
                "Cannot understand the CNN storage order %s." % self.order
            )

    def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
        """Forward to ``brew.image_input`` using this model's order."""
        return brew.image_input(
            self,
            blob_in,
            blob_out,
            order=self.order,
            use_gpu_transform=use_gpu_transform,
            **kwargs
        )

    def VideoInput(self, blob_in, blob_out, **kwargs):
        """Forward to ``brew.video_input``."""
        return brew.video_input(
            self,
            blob_in,
            blob_out,
            **kwargs
        )

    def PadImage(self, blob_in, blob_out, **kwargs):
        """Add a ``PadImage`` op to the net and return its result.

        The result of ``self.net.PadImage`` is returned so that this helper
        behaves like the other operator wrappers on this class.
        """
        # TODO(wyiming): remove this dummy helper later
        return self.net.PadImage(blob_in, blob_out, **kwargs)

    def ConvNd(self, *args, **kwargs):
        """Forward to ``brew.conv_nd`` with this model's conv settings."""
        return brew.conv_nd(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def Conv(self, *args, **kwargs):
        """Forward to ``brew.conv`` with this model's conv settings."""
        return brew.conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def ConvTranspose(self, *args, **kwargs):
        """Forward to ``brew.conv_transpose`` with this model's conv settings."""
        return brew.conv_transpose(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv(self, *args, **kwargs):
        """Forward to ``brew.group_conv`` with this model's conv settings."""
        return brew.group_conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv_Deprecated(self, *args, **kwargs):
        """Forward to ``brew.group_conv_deprecated`` with conv settings."""
        return brew.group_conv_deprecated(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def FC(self, *args, **kwargs):
        """Forward to ``brew.fc``."""
        return brew.fc(self, *args, **kwargs)

    def PackedFC(self, *args, **kwargs):
        """Forward to ``brew.packed_fc``."""
        return brew.packed_fc(self, *args, **kwargs)

    def FC_Prune(self, *args, **kwargs):
        """Forward to ``brew.fc_prune``."""
        return brew.fc_prune(self, *args, **kwargs)

    def FC_Decomp(self, *args, **kwargs):
        """Forward to ``brew.fc_decomp``."""
        return brew.fc_decomp(self, *args, **kwargs)

    def FC_Sparse(self, *args, **kwargs):
        """Forward to ``brew.fc_sparse``."""
        return brew.fc_sparse(self, *args, **kwargs)

    def Dropout(self, *args, **kwargs):
        """Forward to ``brew.dropout`` with order and cuDNN settings."""
        return brew.dropout(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def LRN(self, *args, **kwargs):
        """Forward to ``brew.lrn`` with order and cuDNN settings."""
        return brew.lrn(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def Softmax(self, *args, **kwargs):
        """Forward to ``brew.softmax`` with the cuDNN setting."""
        return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def SpatialBN(self, *args, **kwargs):
        """Forward to ``brew.spatial_bn`` with this model's order."""
        return brew.spatial_bn(self, *args, order=self.order, **kwargs)

    def SpatialGN(self, *args, **kwargs):
        """Forward to ``brew.spatial_gn`` with this model's order."""
        return brew.spatial_gn(self, *args, order=self.order, **kwargs)

    def InstanceNorm(self, *args, **kwargs):
        """Forward to ``brew.instance_norm`` with this model's order."""
        return brew.instance_norm(self, *args, order=self.order, **kwargs)

    def Relu(self, *args, **kwargs):
        """Forward to ``brew.relu`` with order and cuDNN settings."""
        return brew.relu(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def PRelu(self, *args, **kwargs):
        """Forward to ``brew.prelu``."""
        return brew.prelu(self, *args, **kwargs)

    def Concat(self, *args, **kwargs):
        """Forward to ``brew.concat`` with this model's order."""
        return brew.concat(self, *args, order=self.order, **kwargs)

    def DepthConcat(self, *args, **kwargs):
        """The old depth concat function - we should move to use concat."""
        print("DepthConcat is deprecated. use Concat instead.")
        return self.Concat(*args, **kwargs)

    def Sum(self, *args, **kwargs):
        """Forward to ``brew.sum``."""
        return brew.sum(self, *args, **kwargs)

    def Transpose(self, *args, **kwargs):
        """Forward to ``brew.transpose`` with the cuDNN setting."""
        return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def Iter(self, *args, **kwargs):
        """Forward to ``brew.iter``."""
        return brew.iter(self, *args, **kwargs)

    def Accuracy(self, *args, **kwargs):
        """Forward to ``brew.accuracy``."""
        return brew.accuracy(self, *args, **kwargs)

    def MaxPool(self, *args, **kwargs):
        """Forward to ``brew.max_pool`` with order and cuDNN settings."""
        return brew.max_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    def MaxPoolWithIndex(self, *args, **kwargs):
        """Forward to ``brew.max_pool_with_index`` with this model's order."""
        return brew.max_pool_with_index(self, *args, order=self.order, **kwargs)

    def AveragePool(self, *args, **kwargs):
        """Forward to ``brew.average_pool`` with order and cuDNN settings."""
        return brew.average_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    @property
    def XavierInit(self):
        """``(op_name, kwargs)`` pair naming the ``XavierFill`` initializer."""
        return ('XavierFill', {})

    def ConstantInit(self, value):
        """``(op_name, kwargs)`` pair for ``ConstantFill`` with ``value``."""
        return ('ConstantFill', dict(value=value))

    @property
    def MSRAInit(self):
        """``(op_name, kwargs)`` pair naming the ``MSRAFill`` initializer."""
        return ('MSRAFill', {})

    @property
    def ZeroInit(self):
        """``(op_name, kwargs)`` pair for ``ConstantFill`` with no arguments."""
        return ('ConstantFill', {})

    def AddWeightDecay(self, weight_decay):
        """Forward to ``brew.add_weight_decay``."""
        return brew.add_weight_decay(self, weight_decay)

    @property
    def CPU(self):
        """A new ``DeviceOption`` whose device type is CPU."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CPU
        return device_option

    @property
    def GPU(self, gpu_id=0):
        """A new ``DeviceOption`` for ``workspace.GpuDeviceType``.

        NOTE: because this is a property, attribute access cannot supply
        ``gpu_id``; the default of 0 is what ``model.GPU`` yields.
        """
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = workspace.GpuDeviceType
        device_option.device_id = gpu_id
        return device_option
|
||||
106
caffe2/python/context.py
Normal file
106
caffe2/python/context.py
Normal file
@ -0,0 +1,106 @@
|
||||
## @package context
|
||||
# Module caffe2.python.context
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import functools
|
||||
|
||||
|
||||
class _ContextInfo:
    """Per-class stack of active context objects, kept separately per thread."""

    def __init__(self, cls, allow_default):
        """Record the managed class and whether a default may be created.

        Args:
            cls: the context class this stack belongs to.
            allow_default: when true, ``get_active`` instantiates ``cls()``
                if no context is active.
        """
        self.cls = cls
        self.allow_default = allow_default
        self._local_stack = threading.local()

    @property
    def _stack(self):
        """The calling thread's stack, created lazily on first access."""
        if not hasattr(self._local_stack, 'obj'):
            self._local_stack.obj = []
        return self._local_stack.obj

    def enter(self, value):
        """Push ``value`` as the active context for the calling thread."""
        self._stack.append(value)

    def exit(self, value):
        """Pop the active context, which must be equal to ``value``.

        Raises:
            AssertionError: if the stack is empty or its top is not ``value``.
        """
        stack = self._stack
        assert len(stack) > 0, 'Context %s is empty.' % self.cls
        # The pop is performed outside the ``assert`` so that it still
        # happens when assertions are stripped (``python -O``).
        popped = stack.pop()
        assert popped == value

    def get_active(self, required=True):
        """Return the innermost active context for the calling thread.

        When nothing is active: returns ``None`` if ``required`` is false;
        otherwise creates, enters and returns ``cls()`` if defaults are
        allowed.

        Raises:
            AssertionError: if nothing is active, ``required`` is true and
                defaults are not allowed.
        """
        if len(self._stack) == 0:
            if not required:
                return None
            assert self.allow_default, (
                'Context %s is required but none is active.' % self.cls)
            self.enter(self.cls())
        return self._stack[-1]
|
||||
|
||||
|
||||
class _ContextRegistry:
    """Maps each managed class to its lazily created ``_ContextInfo``."""

    def __init__(self):
        self._ctxs = {}

    def get(self, cls):
        """Return the ``_ContextInfo`` for ``cls``, creating it on first use.

        Raises:
            AssertionError: if ``cls`` is not registered yet and is not a
                subclass of ``Managed``.
        """
        known = self._ctxs
        if cls in known:
            return known[cls]
        assert issubclass(cls, Managed), "must be a context managed class, got {}".format(cls)
        defaultable = issubclass(cls, DefaultManaged)
        info = _ContextInfo(cls, allow_default=defaultable)
        known[cls] = info
        return info
|
||||
|
||||
|
||||
_CONTEXT_REGISTRY = _ContextRegistry()
|
||||
|
||||
|
||||
def _context_registry():
    """Return the module-level ``_CONTEXT_REGISTRY`` object."""
    # The name is only read here, so no ``global`` declaration is required.
    return _CONTEXT_REGISTRY
|
||||
|
||||
|
||||
def _get_managed_classes(obj):
    """List the managed classes in the MRO of ``obj``'s class.

    ``Managed`` and ``DefaultManaged`` themselves are left out; the result
    keeps MRO order.
    """
    managed = []
    for klass in inspect.getmro(obj.__class__):
        if issubclass(klass, Managed) and klass != Managed and klass != DefaultManaged:
            managed.append(klass)
    return managed
|
||||
|
||||
|
||||
|
||||
class Managed:
    """
    Managed makes the inherited class a context managed class.

        class Foo(Managed): ...

        with Foo() as f:
            assert f == Foo.current()
    """

    @classmethod
    def current(cls, value=None, required=True):
        """Return ``value`` if given, else the active context for ``cls``.

        Raises:
            AssertionError: if ``value`` is given but is not an instance of
                ``cls``.
        """
        # The registry lookup runs first, even when ``value`` is supplied.
        ctx_info = _context_registry().get(cls)
        if value is None:
            return ctx_info.get_active(required=required)
        assert isinstance(value, cls), (
            'Wrong context type. Expected: %s, got %s.' % (cls, type(value)))
        return value

    def __enter__(self):
        """Enter this object on the stack of each of its managed classes."""
        registry = _context_registry()
        for klass in _get_managed_classes(self):
            registry.get(klass).enter(self)
        return self

    def __exit__(self, *args):
        """Exit this object from the stack of each of its managed classes."""
        registry = _context_registry()
        for klass in _get_managed_classes(self):
            registry.get(klass).exit(self)

    def __call__(self, func):
        """Decorator form: run ``func`` with this object as active context."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper
|
||||
|
||||
|
||||
class DefaultManaged(Managed):
    """
    DefaultManaged is similar to Managed but if there is no parent when
    current() is called it makes a new one.
    """
|
||||
13
caffe2/python/context.pyi
Normal file
13
caffe2/python/context.pyi
Normal file
@ -0,0 +1,13 @@
|
||||
from typing import Optional, TypeVar, Type
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
# Type stub for the context-managed base class.
class Managed:
    # Returns `value` when given, otherwise the active context for `cls`.
    @classmethod
    def current(cls: Type[_T], value: Optional[_T] = None, required: bool = True) -> _T: ...

    # Decorator form: the wrapped callable keeps its type.
    def __call__(self, func: _T) -> _T: ...

    def __enter__(self: _T) -> _T: ...

    # Declared so type checkers accept `with Managed(): ...`; the runtime
    # implementation accepts *args and returns nothing.
    def __exit__(self, *args: object) -> None: ...
|
||||
|
||||
# Type stub for DefaultManaged; it declares no members beyond those of Managed.
class DefaultManaged(Managed): ...
|
||||
67
caffe2/python/context_test.py
Normal file
67
caffe2/python/context_test.py
Normal file
@ -0,0 +1,67 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import context, test_util
|
||||
from threading import Thread
|
||||
|
||||
|
||||
class MyContext(context.Managed):
    """Test fixture: a managed context class deriving from context.Managed."""
    pass
|
||||
|
||||
class DefaultMyContext(context.DefaultManaged):
    """Test fixture: a context class deriving from context.DefaultManaged."""
    pass
|
||||
|
||||
class ChildMyContext(MyContext):
    """Test fixture: a subclass of MyContext, used to test inherited contexts."""
    pass
|
||||
|
||||
|
||||
class TestContext(test_util.TestCase):
    """Tests for the Managed / DefaultManaged context classes."""

    def use_my_context(self):
        """Thread body: repeatedly enter MyContext and check ``current()``.

        Any exception is collected in ``self._exceptions`` rather than
        raised, so the spawning test can re-raise it.
        """
        try:
            for _ in range(100):
                with MyContext() as active:
                    for _ in range(100):
                        self.assertTrue(MyContext.current() == active)
        except Exception as exc:
            self._exceptions.append(exc)

    def testMultiThreaded(self):
        """Run ``use_my_context`` on 8 threads and surface any failure."""
        workers = []
        self._exceptions = []
        for _ in range(8):
            worker = Thread(target=self.use_my_context)
            worker.start()
            workers.append(worker)
        for worker in workers:
            worker.join()
        # Re-raise the first collected exception, if any.
        if self._exceptions:
            raise self._exceptions[0]

    @MyContext()
    def testDecorator(self):
        """A context instance used as a decorator is active in the body."""
        self.assertIsNotNone(MyContext.current())

    def testNonDefaultCurrent(self):
        """``current()`` on a non-default class with nothing active."""
        with self.assertRaises(AssertionError):
            MyContext.current()

        explicit = MyContext()
        self.assertEqual(MyContext.current(value=explicit), explicit)

        self.assertIsNone(MyContext.current(required=False))

    def testDefaultCurrent(self):
        """A DefaultManaged class creates an instance on demand."""
        self.assertIsInstance(DefaultMyContext.current(), DefaultMyContext)

    def testNestedContexts(self):
        """Contexts of unrelated classes are tracked independently."""
        with MyContext() as outer, DefaultMyContext() as inner:
            self.assertEqual(DefaultMyContext.current(), inner)
            self.assertEqual(MyContext.current(), outer)

    def testChildClasses(self):
        """A child context is current for both itself and its parent."""
        with ChildMyContext() as child:
            self.assertEqual(ChildMyContext.current(), child)
            self.assertEqual(MyContext.current(), child)
|
||||
574
caffe2/python/control.py
Normal file
574
caffe2/python/control.py
Normal file
@ -0,0 +1,574 @@
|
||||
## @package control
|
||||
# Module caffe2.python.control
|
||||
"""
|
||||
Implement functions for controlling execution of nets and steps, including
|
||||
Do
|
||||
DoParallel
|
||||
For-loop
|
||||
While-loop
|
||||
Do-While-loop
|
||||
Switch
|
||||
If
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import core
|
||||
|
||||
|
||||
# Used to generate names of the steps created by the control functions.
|
||||
# It is actually the internal index of these steps.
|
||||
_current_idx = 1
|
||||
_used_step_names = set()
|
||||
|
||||
|
||||
def _get_next_step_name(control_name, base_name):
    """Return an unused step name of the form ``base_name/control_name``.

    If that name is already taken, ``_<index>`` suffixes drawn from the
    module-level counter are tried until a free one is found. The chosen
    name is recorded as used.
    """
    global _current_idx
    stem = '%s/%s' % (base_name, control_name)
    candidate = stem
    while candidate in _used_step_names:
        candidate = '%s_%d' % (stem, _current_idx)
        _current_idx += 1
    _used_step_names.add(candidate)
    return candidate
|
||||
|
||||
|
||||
def _MakeList(input):
    """Normalize a tuple of arguments into a list.

    Example:
      (a, b, c)   --> [a, b, c]
      (a)         --> [a]
      ([a, b, c]) --> [a, b, c]  (the inner list itself is returned)

    Raises:
        ValueError: if ``input`` is empty.
    """
    if len(input) == 0:
        raise ValueError(
            'input cannot be empty.')
    if len(input) > 1:
        return list(input)
    only = input[0]
    return only if isinstance(only, list) else [only]
|
||||
|
||||
|
||||
def _IsNets(nets_or_steps):
    """Tell whether the argument is a ``core.Net`` or a list of only nets.

    An empty list counts as a list of nets.
    """
    if not isinstance(nets_or_steps, list):
        return isinstance(nets_or_steps, core.Net)
    for item in nets_or_steps:
        if not isinstance(item, core.Net):
            return False
    return True
|
||||
|
||||
|
||||
def _PrependNets(nets_or_steps, *nets):
    """Return a list with ``nets`` placed before ``nets_or_steps``.

    When ``nets_or_steps`` are not all nets, ``nets`` are first wrapped
    into a single ``Do('prepend', ...)`` step.
    """
    body = _MakeList((nets_or_steps,))
    prefix = _MakeList(nets)
    if not _IsNets(body):
        prefix = [Do('prepend', prefix)]
    return prefix + body
|
||||
|
||||
|
||||
def _AppendNets(nets_or_steps, *nets):
    """Return a list with ``nets`` placed after ``nets_or_steps``.

    When ``nets_or_steps`` are not all nets, ``nets`` are first wrapped
    into a single ``Do('append', ...)`` step.
    """
    body = _MakeList((nets_or_steps,))
    suffix = _MakeList(nets)
    if not _IsNets(body):
        suffix = [Do('append', suffix)]
    return body + suffix
|
||||
|
||||
|
||||
def GetConditionBlobFromNet(condition_net):
    """
    The condition blob is the last external_output that must
    be a single bool

    Args:
        condition_net: a net whose ``Proto()`` lists at least one
            external output.

    Returns:
        A ``core.BlobReference`` to the last external output.

    Raises:
        AssertionError: if the net has no external output.
    """
    net_proto = condition_net.Proto()
    # ``Proto`` is a method, so the name must be read from its result;
    # reading it from the bound method raised AttributeError while the
    # assertion message was being built.
    assert len(net_proto.external_output) > 0, (
        "Condition net %s must has at least one external output" %
        net_proto.name)
    # we need to use a blob reference here instead of a string
    # otherwise, it will add another name_scope to the input later
    # when we create new ops (such as OR of two inputs)
    return core.BlobReference(net_proto.external_output[-1])
|
||||
|
||||
|
||||
def BoolNet(*blobs_with_bool_value):
    """A net assigning constant bool values to blobs. It is mainly used for
    initializing condition blobs, for example, in multi-task learning, we
    need to access reader_done blobs before reader_net run. In that case,
    the reader_done blobs must be initialized.

    Args:
    blobs_with_bool_value: one or more (blob, bool_value) pairs. The net will
    assign each bool_value to the corresponding blob.

    returns
    bool_net: A net assigning constant bool values to blobs.

    Examples:
    - BoolNet((blob_1, bool_value_1), ..., (blob_n, bool_value_n))
    - BoolNet([(blob_1, bool_value_1), ..., (blob_n, bool_value_n)])
    - BoolNet((cond_1, bool_value_1))
    """
    pairs = _MakeList(blobs_with_bool_value)
    bool_net = core.Net('bool_net')
    for target_blob, flag in pairs:
        filled = bool_net.ConstantFill(
            [],
            [target_blob],
            shape=[],
            value=flag,
            dtype=core.DataType.BOOL)
        bool_net.AddExternalOutput(filled)
    return bool_net
|
||||
|
||||
|
||||
def NotNet(condition_blob_or_net):
    """Not of a condition blob or net

    Args:
    condition_blob_or_net can be either blob or net. If condition_blob_or_net
    is Net, the condition is its last external_output
    that must be a single bool.

    returns
    not_net: the net NOT the input
    out_blob: the output blob of the not_net
    """
    condition_blob = (
        GetConditionBlobFromNet(condition_blob_or_net)
        if isinstance(condition_blob_or_net, core.Net)
        else condition_blob_or_net
    )

    not_net = core.Net('not_net')
    out_blob = not_net.Not(condition_blob)
    not_net.AddExternalOutput(out_blob)
    return not_net, out_blob
|
||||
|
||||
|
||||
def _CopyConditionBlobNet(condition_blob):
    """Make a condition net that copies the condition_blob

    Args:
    condition_blob is a single bool.

    returns
    condition_net: the net copying the input blob
    out_blob: the output blob of the condition_net
    """
    copy_net = core.Net('copy_condition_blob_net')
    copied_blob = copy_net.Copy(condition_blob)
    copy_net.AddExternalOutput(copied_blob)
    return copy_net, copied_blob
|
||||
|
||||
|
||||
def MergeConditionNets(name, condition_nets, relation):
    """
    Merge multi condition nets into a single condition nets.

    Args:
        name: name of the new condition net.
        condition_nets: a list of condition nets. The last external_output
                        of each condition net must be single bool value.
        relation: can be 'And' or 'Or'.

    Returns:
        - A new condition net. Its last external output is relation of all
          condition_nets.
        - A non-list argument is returned unchanged; an empty list gives
          None; a one-element list gives that element.
    """
    if not isinstance(condition_nets, list):
        return condition_nets
    if not condition_nets:
        return None
    if len(condition_nets) == 1:
        return condition_nets[0]

    merged_net = core.Net(name)
    last_cond = None
    for index, cond_net in enumerate(condition_nets):
        net_proto = cond_net.Proto()
        assert net_proto.device_option == merged_net.Proto().device_option
        assert net_proto.type == merged_net.Proto().type
        merged_net.Proto().op.extend(net_proto.op)
        merged_net.Proto().external_input.extend(net_proto.external_input)
        # discard external outputs as we're combining them together
        curr_cond = GetConditionBlobFromNet(cond_net)
        if index == 0:
            last_cond = curr_cond
        else:
            last_cond = merged_net.__getattr__(relation)([last_cond, curr_cond])
        # merge attributes
        for k, v in cond_net._attr_dict.items():
            merged_net._attr_dict[k] += v

    merged_net.AddExternalOutput(last_cond)
    return merged_net
|
||||
|
||||
|
||||
def CombineConditions(name, condition_nets, relation):
    """
    Combine conditions of multi nets into a single condition nets. Unlike
    MergeConditionNets, the actual body of condition_nets is not copied into
    the combine condition net.

    One example is about multi readers. Each reader net has a reader_done
    condition. When we want to check whether all readers are done, we can
    use this function to build a new net.

    Args:
        name: name of the new condition net.
        condition_nets: a list of condition nets. The last external_output
                        of each condition net must be single bool value.
        relation: can be 'And' or 'Or'.

    Returns:
        - A new condition net. Its last external output is relation of all
          condition_nets.
        - None when condition_nets is empty or otherwise falsy.

    Raises:
        ValueError: if condition_nets is truthy but not a list.
    """
    if not condition_nets:
        return None
    if not isinstance(condition_nets, list):
        raise ValueError('condition_nets must be a list of nets.')

    if len(condition_nets) == 1:
        # A single condition is exposed through a net that copies its blob.
        condition_blob = GetConditionBlobFromNet(condition_nets[0])
        condition_net, _ = _CopyConditionBlobNet(condition_blob)
        return condition_net

    combined_net = core.Net(name)
    last_cond = GetConditionBlobFromNet(condition_nets[0])
    for cond_net in condition_nets[1:]:
        curr_cond = GetConditionBlobFromNet(cond_net)
        last_cond = combined_net.__getattr__(relation)(
            [last_cond, curr_cond])

    combined_net.AddExternalOutput(last_cond)
    return combined_net
|
||||
|
||||
|
||||
def Do(name, *nets_or_steps):
    """
    Execute the sequence of nets or steps once.

    A single ExecutionStep argument is returned as is; anything else is
    wrapped in a new scoped execution step.

    Examples:
    - Do('myDo', net1, net2, ..., net_n)
    - Do('myDo', list_of_nets)
    - Do('myDo', step1, step2, ..., step_n)
    - Do('myDo', list_of_steps)
    """
    items = _MakeList(nets_or_steps)
    is_single_step = (
        len(items) == 1 and isinstance(items[0], core.ExecutionStep))
    if is_single_step:
        return items[0]
    return core.scoped_execution_step(
        _get_next_step_name('Do', name), items)
|
||||
|
||||
|
||||
def DoParallel(name, *nets_or_steps):
    """
    Execute the nets or steps in parallel, waiting for all of them to finish

    A single ExecutionStep argument is returned as is; anything else is
    wrapped in a new scoped execution step with concurrent substeps.

    Examples:
    - DoParallel('pDo', net1, net2, ..., net_n)
    - DoParallel('pDo', list_of_nets)
    - DoParallel('pDo', step1, step2, ..., step_n)
    - DoParallel('pDo', list_of_steps)
    """
    items = _MakeList(nets_or_steps)
    is_single_step = (
        len(items) == 1 and isinstance(items[0], core.ExecutionStep))
    if is_single_step:
        return items[0]
    return core.scoped_execution_step(
        _get_next_step_name('DoParallel', name),
        items,
        concurrent_substeps=True)
|
||||
|
||||
|
||||
def _RunOnceIf(name, condition_blob_or_net, nets_or_steps):
    """
    Run nets_or_steps a single time if condition_blob_or_net evaluates as
    true.

    When condition_blob_or_net is a Net, the condition is its last
    external_output, which must be a single bool. That net is placed in
    front of nets_or_steps so the condition is available when needed.
    """
    negated_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        leading_nets = (condition_blob_or_net, negated_net)
    else:
        leading_nets = (negated_net,)
    nets_or_steps = _PrependNets(nets_or_steps, *leading_nets)

    def make_step(control_name):
        # The step stops as soon as the negated condition (stop_blob) is set.
        return core.scoped_execution_step(
            _get_next_step_name(control_name, name),
            nets_or_steps,
            should_stop_blob=stop_blob,
            only_once=True,
        )

    if not _IsNets(nets_or_steps):
        return make_step('_RunOnceIf')

    # Reset stop_blob to False before the conditional step runs.
    bool_net = BoolNet((stop_blob, False))
    inner_step = make_step('_RunOnceIf-inner')
    return Do(name + '/_RunOnceIf', bool_net, inner_step)
|
||||
|
||||
|
||||
def _RunOnceIfNot(name, condition_blob_or_net, nets_or_steps):
    """
    Counterpart of _RunOnceIf(): run nets_or_steps a single time if
    condition_blob_or_net evaluates as false.
    """
    if not isinstance(condition_blob_or_net, core.Net):
        # For a plain blob, _CopyConditionBlobNet supplies both the net to
        # run first and the blob the step watches (presumably a copy of the
        # condition -- confirm against that helper).
        leading_net, condition_blob = _CopyConditionBlobNet(
            condition_blob_or_net)
    else:
        condition_blob = GetConditionBlobFromNet(condition_blob_or_net)
        leading_net = condition_blob_or_net
    nets_or_steps = _PrependNets(nets_or_steps, leading_net)

    step_name = _get_next_step_name('_RunOnceIfNot', name)
    return core.scoped_execution_step(
        step_name,
        nets_or_steps,
        should_stop_blob=condition_blob,
        only_once=True,
    )
|
||||
|
||||
|
||||
def For(name, nets_or_steps, iter_num):
    """
    Run nets_or_steps iter_num times.

    Args:
        nets_or_steps: an ExecutionStep, a Net, a list of ExecutionSteps or
            a list of nets.
        iter_num: how many times nets_or_steps should be executed.

    Returns:
        An ExecutionStep instance.
    """
    # One net creates the counter, a second one counts it down on every
    # iteration and yields the blob that stops the loop.
    init_net = core.Net('init-net')
    counter = init_net.CreateCounter([], init_count=iter_num)
    iter_net = core.Net('For-iter')
    done_blob = iter_net.CountDown([counter])

    inner_name = _get_next_step_name('For-inner', name)
    loop_body = _PrependNets(nets_or_steps, iter_net)
    for_step = core.scoped_execution_step(
        inner_name, loop_body, should_stop_blob=done_blob)

    init_step = Do(name + '/For-init-net', init_net)
    return Do(name + '/For', init_step, for_step)
|
||||
|
||||
|
||||
def While(name, condition_blob_or_net, nets_or_steps):
    """
    Run nets_or_steps for as long as condition_blob_or_net returns true.

    Args:
        condition_blob_or_net: if it is a Net, its last external_output
            must be a single bool.
        nets_or_steps: an ExecutionStep, a Net, a list of ExecutionSteps or
            a list of nets.

    Returns:
        An ExecutionStep instance.
    """
    negated_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        leading_nets = (condition_blob_or_net, negated_net)
    else:
        leading_nets = (negated_net,)
    nets_or_steps = _PrependNets(nets_or_steps, *leading_nets)

    def make_step(control_name):
        return core.scoped_execution_step(
            _get_next_step_name(control_name, name),
            nets_or_steps,
            should_stop_blob=stop_blob,
        )

    if not _IsNets(nets_or_steps):
        return make_step('While')

    # Here the loop step consists of sub-nets:
    #   [condition_blob_or_net, condition_not_net, nets_or_steps]
    # Should stop_blob already be True (e.g. While() is called twice), the
    # loop would exit right after running condition_blob_or_net, so a
    # BoolNet resets stop_blob to False first.
    bool_net = BoolNet((stop_blob, False))
    inner_step = make_step('While-inner')
    return Do(name + '/While', bool_net, inner_step)
|
||||
|
||||
|
||||
def Until(name, condition_blob_or_net, nets_or_steps):
    """
    Like While(), but nets_or_steps is run for as long as
    condition_blob_or_net returns false.
    """
    if not isinstance(condition_blob_or_net, core.Net):
        # A plain blob is watched directly; nothing is prepended.
        stop_blob = core.BlobReference(str(condition_blob_or_net))
    else:
        stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
        nets_or_steps = _PrependNets(nets_or_steps, condition_blob_or_net)

    step_name = _get_next_step_name('Until', name)
    return core.scoped_execution_step(
        step_name, nets_or_steps, should_stop_blob=stop_blob)
|
||||
|
||||
|
||||
def DoWhile(name, condition_blob_or_net, nets_or_steps):
    """
    Run nets_or_steps for as long as condition_blob_or_net returns true,
    executing nets_or_steps before the condition is evaluated.

    Args:
        condition_blob_or_net: if it is a Net, its last external_output
            must be a single bool.
        nets_or_steps: an ExecutionStep, a Net, a list of ExecutionSteps or
            a list of nets.

    Returns:
        An ExecutionStep instance.
    """
    negated_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        trailing_nets = (condition_blob_or_net, negated_net)
    else:
        trailing_nets = (negated_net,)
    nets_or_steps = _AppendNets(nets_or_steps, *trailing_nets)

    # Should stop_blob already be True (e.g. DoWhile() is called twice), the
    # loop would exit after the first net/step of nets_or_steps, which is
    # not wanted; a BoolNet resets stop_blob to False first.
    bool_net = BoolNet((stop_blob, False))
    inner_step = core.scoped_execution_step(
        _get_next_step_name('DoWhile-inner', name),
        nets_or_steps,
        should_stop_blob=stop_blob,
    )
    return Do(name + '/DoWhile', bool_net, inner_step)
|
||||
|
||||
|
||||
def DoUntil(name, condition_blob_or_net, nets_or_steps):
    """
    Like DoWhile(), but nets_or_steps is run for as long as
    condition_blob_or_net returns false. nets_or_steps is executed before
    the condition is evaluated.

    Special case: when condition_blob_or_net is a blob that is pre-set to
    true, only the first net/step of nets_or_steps is executed and the loop
    exits. Pay attention to the initial value of the condition blob when
    using DoUntil(), especially when DoUntil() is called twice.
    """
    if isinstance(condition_blob_or_net, core.Net):
        nets_or_steps = _AppendNets(nets_or_steps, condition_blob_or_net)
        stop_blob = GetConditionBlobFromNet(condition_blob_or_net)

        # Should stop_blob already be True (e.g. DoUntil() is called twice),
        # the loop would exit after the first net/step of nets_or_steps,
        # which is not wanted; a BoolNet resets stop_blob to False first.
        bool_net = BoolNet((stop_blob, False))
        inner_step = core.scoped_execution_step(
            _get_next_step_name('DoUntil-inner', name),
            nets_or_steps,
            should_stop_blob=stop_blob,
        )
        return Do(name + '/DoUntil', bool_net, inner_step)

    # Plain blob: it is watched directly and is not reset beforehand.
    stop_blob = core.BlobReference(condition_blob_or_net)
    step_name = _get_next_step_name('DoUntil', name)
    return core.scoped_execution_step(
        step_name, nets_or_steps, should_stop_blob=stop_blob)
|
||||
|
||||
|
||||
def Switch(name, *conditions):
    """
    Run the steps whose condition is true.

    Every condition is a tuple (condition_blob_or_net, nets_or_steps).
    Note:
      1. Several steps may run if several conditions are true.
      2. The condition_blob_or_net of every step (when it is a Net) is
         executed once.

    Examples:
    - Switch('name', (cond_1, net_1), (cond_2, net_2), ..., (cond_n, net_n))
    - Switch('name', [(cond_1, net1), (cond_2, net_2), ..., (cond_n, net_n)])
    - Switch('name', (cond_1, net_1))
    """
    cases = _MakeList(conditions)
    step_name = _get_next_step_name('Switch', name)
    branch_steps = []
    for cond, step in cases:
        branch_steps.append(_RunOnceIf(name + '/Switch', cond, step))
    return core.scoped_execution_step(step_name, branch_steps)
|
||||
|
||||
|
||||
def SwitchNot(name, *conditions):
    """
    Like Switch(), but runs the steps whose condition is False.
    """
    cases = _MakeList(conditions)
    step_name = _get_next_step_name('SwitchNot', name)
    branch_steps = []
    for cond, step in cases:
        branch_steps.append(_RunOnceIfNot(name + '/SwitchNot', cond, step))
    return core.scoped_execution_step(step_name, branch_steps)
|
||||
|
||||
|
||||
def If(name, condition_blob_or_net,
       true_nets_or_steps, false_nets_or_steps=None):
    """
    Evaluate (or execute) condition_blob_or_net first; run
    true_nets_or_steps if the condition is true, false_nets_or_steps
    otherwise.

    When condition_blob_or_net is a Net, the condition is its last
    external_output, which must be a single bool. That Net is executed
    before both true/false_nets_or_steps so the condition is available.
    """
    if not false_nets_or_steps:
        # No else-branch given: a single conditional step is enough.
        return _RunOnceIf(
            name + '/If', condition_blob_or_net, true_nets_or_steps)

    # The false branch is handed only the condition blob, not the net.
    condition_blob = (
        GetConditionBlobFromNet(condition_blob_or_net)
        if isinstance(condition_blob_or_net, core.Net)
        else condition_blob_or_net)

    true_step = _RunOnceIf(
        name + '/If-true', condition_blob_or_net, true_nets_or_steps)
    false_step = _RunOnceIfNot(
        name + '/If-false', condition_blob, false_nets_or_steps)
    return Do(name + '/If', true_step, false_step)
|
||||
|
||||
|
||||
def IfNot(name, condition_blob_or_net,
          true_nets_or_steps, false_nets_or_steps=None):
    """
    Run true_nets_or_steps if condition_blob_or_net returns false,
    false_nets_or_steps otherwise.
    """
    if not false_nets_or_steps:
        # No else-branch given: a single conditional step is enough.
        return _RunOnceIfNot(
            name + '/IfNot', condition_blob_or_net, true_nets_or_steps)

    # The second branch is handed only the condition blob, not the net.
    condition_blob = (
        GetConditionBlobFromNet(condition_blob_or_net)
        if isinstance(condition_blob_or_net, core.Net)
        else condition_blob_or_net)

    true_step = _RunOnceIfNot(
        name + '/IfNot-true', condition_blob_or_net, true_nets_or_steps)
    false_step = _RunOnceIf(
        name + '/IfNot-false', condition_blob, false_nets_or_steps)
    return Do(name + '/IfNot', true_step, false_step)
|
||||
706
caffe2/python/control_ops_grad.py
Normal file
706
caffe2/python/control_ops_grad.py
Normal file
@ -0,0 +1,706 @@
|
||||
## @package control_ops_grad
|
||||
# Module caffe2.python.control_ops_grad
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
||||
|
||||
def gen_do_gradient(op, g_output):
    """
    Build the gradient Do operator for a forward Do op.

    Args:
        op: the forward Do operator definition.
        g_output: gradient blobs, one per output of ``op``; an entry may be
            empty when no gradient exists for that output.

    Returns:
        A pair ``(grad_ops, g_input)``. ``grad_ops`` holds any Copy ops
        produced while de-duplicating ``g_output`` followed by the gradient
        Do op. ``g_input`` has one entry per input of ``op`` (workspace
        pointer excluded): the gradient blob name, or None when the subnet
        produces no gradient for that input. Both are empty lists when the
        subnet's backward pass yields no operators.
    """
    from caffe2.python.core import BlobReference
    subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name = \
        _do_op_sanity_check_and_process(op)

    assert len(g_output) == len(op.output), \
        "Different number of gradient blobs and Do op outputs"

    grad_ops, g_output = dedupe_g_output(op, g_output)

    # Seen from the outer net, Do is an operator with some inputs and
    # outputs, and its gradient operator must write the input gradient blobs
    # while having access to inputs, outputs and output gradients.
    # Seen from the inner net, Do is a subnet plus blob bindings: the output
    # gradients are forwarded into the inner workspace, the backward pass is
    # generated there, and the input gradients are forwarded back out.

    # The last input and the last output are the workspace pointer blob.
    outer_outputs = [str(blob) for blob in op.output][:-1]
    outer_inputs = [str(blob) for blob in op.input][:-1]

    inner_outputs = [outer_to_inner_map[blob] for blob in outer_outputs]

    # backward_seed drives backward pass generation inside the subnet;
    # inner_grad_to_outer remembers which outer gradient blob stands behind
    # each inner gradient name.
    backward_seed = {}
    inner_grad_to_outer = {}
    for inner_out, outer_grad in zip(inner_outputs, g_output):
        if not outer_grad:
            continue
        inner_grad = inner_out + "/_DO_OPERATOR_INNER_GRAD_"
        backward_seed[BlobReference(inner_out)] = BlobReference(inner_grad)
        inner_grad_to_outer[inner_grad] = str(outer_grad)
    assert len(inner_grad_to_outer) > 0, "Empty initial gradient map for Do op"

    inner_grad_ops, inner_grad_names_map = _gen_subgradient_pass(
        subnet, backward_seed)

    if len(inner_grad_ops) == 0:
        return [], []

    copy_ops = []
    g_input = []
    grad_do_outputs = []
    bindings = {}
    for outer_in in outer_inputs:
        inner_in = outer_to_inner_map[outer_in]
        if inner_in not in inner_grad_names_map:
            g_input.append(None)
            continue

        inner_grad_in = inner_grad_names_map[inner_in]
        outer_grad_in = outer_in + "_grad"

        # The inner gradient blob of an input may also have to be bound to
        # another outer blob. Example:
        #
        #    // y - param initialized in init_net
        #    x = ...
        #    z = ...
        #    with ops.IfNet(...):
        #        ops.Add([z, x], y) # inner Do block
        #    loss = f(..., y, ...)
        #
        # x, y and z are external to the inner Do block; its inputs are z
        # and x, its output is y, and the gradient of x equals the gradient
        # of y. The outer y_grad gets bound to the inner name
        # y/_DO_OPERATOR_INNER_GRAD_, and the generated gradient map may
        # then link x to that very same inner blob. Exporting it as x_grad
        # too would bind one inner name to two outer blobs, whereas the
        # blob mapping scheme needs a bijection between a subset of inner
        # names and all of Do's outer input/output names. (The mapping is
        # transparent: outer blobs are merely visible inside Do's workspace
        # under possibly different names, nothing is copied by it.)
        # Hence the gradient is copied to a dedicated inner blob and that
        # copy is what gets bound to the outer gradient name.

        # TODO(iliacher): Remove unnecessary blob copying

        inner_grad_copy = inner_in + "/_DO_OPERATOR_INNER_GRAD_COPY_"
        copy_ops.append(
            _prepare_blob_copy_op(inner_grad_in, inner_grad_copy))

        bindings[inner_grad_copy] = outer_grad_in
        grad_do_outputs.append(outer_grad_in)
        g_input.append(outer_grad_in)

    grad_do_inputs = []
    overwritten = set()
    saved_locals = set()
    for grad_op in inner_grad_ops:
        produced = [str(blob) for blob in grad_op.output]
        for in_name in (str(blob) for blob in grad_op.input):
            if in_name in overwritten:
                # produced by an earlier gradient op, not an outside value
                continue
            # external blob first, external gradient blob second
            outer_name = (inner_to_outer_map.get(in_name, None)
                          or inner_grad_to_outer.get(in_name, None))
            if not outer_name:
                # local blob: its value comes from the saved forward op
                # workspace
                saved_locals.add(in_name)
                continue
            outer_name = str(outer_name)
            if outer_name not in grad_do_inputs:
                grad_do_inputs.append(outer_name)
            bindings[in_name] = outer_name
        overwritten.update(produced)

    # the inner gradient copy ops run after the generated backward pass
    inner_grad_ops += copy_ops

    gradient_do_def = _prepare_gradient_do_op(
        fwd_op=op,
        fwd_net=subnet,
        grad_ops=inner_grad_ops,
        inputs=grad_do_inputs,
        outputs=grad_do_outputs,
        blob_bindings=bindings,
        saved_fwd_blobs=saved_locals,
        workspace_blob_name=workspace_blob_name)
    grad_ops.append(gradient_do_def)

    # validate the op that was just assembled
    _do_op_sanity_check_and_process(gradient_do_def)

    return grad_ops, g_input
|
||||
|
||||
|
||||
def dedupe_g_output(op, g_output):
    """
    Make the gradient blobs of an op's outputs pairwise distinct.

    While generating a gradient op the same gradient blob may show up for
    several forward outputs, yet the Do operator needs a bijection between
    inner and outer names. A gradient that was already used by another
    output is therefore duplicated via a Copy op under a new name.

    Args:
        op: forward operator; only ``op.output`` is read.
        g_output: gradient blobs, matched by position with ``op.output``.

    Returns:
        A pair ``(grad_ops, deduped_g_output)``: the Copy ops that were
        created and the gradient list with duplicates replaced. Empty
        entries of ``g_output`` are passed through unchanged.
    """
    copy_ops = []
    unique_grads = []
    grad_by_output = {}
    for out_name, grad in zip(op.output, g_output):
        if not grad:
            unique_grads.append(grad)
            continue

        if out_name in grad_by_output:
            # same output listed again: reuse the gradient chosen earlier
            unique_grads.append(grad_by_output[out_name])
            continue

        if grad in grad_by_output.values():
            # gradient already taken by another output: copy it
            fresh_name = out_name + "_" + grad + "_DEDUP"
            assert fresh_name not in grad_by_output.values()
            copy_op = caffe2_pb2.OperatorDef()
            copy_op.type = "Copy"
            copy_op.input.extend([grad])
            copy_op.output.extend([fresh_name])
            copy_ops.append(copy_op)
            grad = fresh_name

        grad_by_output[out_name] = grad
        unique_grads.append(grad)
    return copy_ops, unique_grads
|
||||
|
||||
|
||||
def gen_while_gradient(op, g_output):
    """
    Build the gradient While operator for a forward While op.

    The loop body of ``op`` must be a single Do op that does not have
    reuse_workspace set.

    Args:
        op: the forward While operator definition.
        g_output: gradient blobs, one per output of ``op``.

    Returns:
        A pair ``(grad_ops, g_input)``: the generated operators and, for
        every input of ``op``, the gradient blob name taken from the loop
        body's gradient map (None when it has no entry).
    """
    from caffe2.python.core import BlobReference
    assert op.type == "While", "Expected While op"
    assert len(op.input) > 0, "Expected at least one input in While op"

    assert len(op.output) == len(g_output), \
        "Different number of gradient blobs and While op outputs"

    grad_ops, g_output = dedupe_g_output(op, g_output)

    # forward output blob -> its gradient blob, for outputs that have one
    init_grad_map = {}
    for out_blob, grad_blob in zip(op.output, g_output):
        if not grad_blob:
            continue
        init_grad_map[BlobReference(str(out_blob))] = BlobReference(grad_blob)
    assert len(init_grad_map) > 0, "Empty initial gradient map for While op"

    loop_net = _get_net_argument(op, "loop_net")
    assert loop_net, "Expected loop subnet in While op"
    assert len(loop_net.op) == 1 and loop_net.op[0].type == "Do", \
        "Gradient While op requires single Do op as a loop body"
    body_op = loop_net.op[0]
    body_args = _get_do_arguments(body_op)
    assert "reuse_workspace" not in body_args or \
        not body_args["reuse_workspace"], \
        "Gradient While op requires Do loop body op without reuse_workspace set"

    assert len(body_op.output) > 0, "Expected Do op with at least one output"
    # the last output of the Do op is the workspace blob
    workspace_blob = body_op.output[-1]

    loop_grad_net, loop_grad_map, loop_input_names, loop_output_names = \
        _gen_subnet_gradient(loop_net, init_grad_map)
    assert loop_grad_net, "Failed to get gradient net for loop body in While op"

    grad_ops += _prepare_gradient_while_ops(
        fwd_op=op,
        input_names=loop_input_names,
        output_names=loop_output_names,
        loop_grad_net=loop_grad_net,
        workspace_blob=workspace_blob,
        init_grad_map=init_grad_map,
        loop_grad_map=loop_grad_map)

    g_input = [loop_grad_map.get(str(blob), None) for blob in op.input]
    return grad_ops, g_input
|
||||
|
||||
|
||||
def _prepare_gradient_while_ops(
        fwd_op, input_names, output_names, loop_grad_net, workspace_blob,
        init_grad_map, loop_grad_map):
    """
    Construct the operators that make up a gradient While op.

    Args:
        fwd_op: forward While op; the gradient op starts as a copy of it.
        input_names: input blob names for the gradient op.
        output_names: output blob names for the gradient op.
        loop_grad_net: gradient loop body net.
        workspace_blob: blob that holds the forward workspaces stack.
        init_grad_map: initial gradient to forward blob map.
        loop_grad_map: gradient blob map for the loop's body.

    Returns:
        A list with the ops of the condition-initialization net followed by
        the gradient While op definition.
    """
    grad_while = caffe2_pb2.OperatorDef()
    grad_while.CopyFrom(fwd_op)
    if grad_while.name:
        grad_while.name += "_grad"

    loop_arg = caffe2_pb2.Argument()
    loop_arg.name = "loop_net"
    loop_arg.n.CopyFrom(loop_grad_net)

    cond_arg = caffe2_pb2.Argument()
    cond_arg.name = "cond_net"
    from caffe2.python.core import Net, BlobReference
    # The condition net uses the HasScope op to check whether forward
    # workspaces are still left; the init net runs the same check once
    # ahead of the loop.
    cond_net = Net('gradient_loop_cond_net')
    cond_init_net = Net('gradient_loop_cond_net_init')
    cond_blob = cond_net.NextScopedBlob(cond_net.Name() + '/cond')
    cond_init_net.HasScope(workspace_blob, cond_blob)
    cond_net.HasScope(workspace_blob, cond_blob)
    for fwd_blob, seed_grad in init_grad_map.items():
        fwd_name = str(fwd_blob)
        if fwd_name not in loop_grad_map:
            continue
        loop_grad_name = loop_grad_map[fwd_name]
        if loop_grad_name != str(seed_grad):
            # The loop body reports this gradient under a different name
            # than the initial one: copy loop -> initial in the condition
            # net and initial -> loop in the init net.
            cond_net.Copy(BlobReference(loop_grad_name), seed_grad)
            cond_init_net.Copy(seed_grad, BlobReference(loop_grad_name))
    cond_arg.n.CopyFrom(cond_net.Proto())

    # Replace everything inherited from the forward op.
    del grad_while.arg[:]
    grad_while.arg.extend([loop_arg, cond_arg])

    del grad_while.control_input[:]
    del grad_while.input[:]
    # the condition blob is the first input
    grad_while.input.extend(
        [str(cond_blob).encode('utf-8')] + list(input_names))
    del grad_while.output[:]
    grad_while.output.extend(output_names)
    grad_while.is_gradient_op = True
    return list(cond_init_net.Proto().op) + [grad_while]
|
||||
|
||||
|
||||
def _get_do_arguments(do_op):
|
||||
assert do_op.type == "Do", "Expected Do op"
|
||||
args = {}
|
||||
for arg in do_op.arg:
|
||||
if not arg.name:
|
||||
continue
|
||||
if arg.name == "net":
|
||||
assert arg.n, "Expected non empty net argument"
|
||||
args["net"] = arg.n
|
||||
elif arg.name == "reuse_workspace":
|
||||
assert arg.i, "Expected non empty reuse_workspace argument"
|
||||
args["reuse_workspace"] = bool(arg.i)
|
||||
elif arg.name == "inner_blobs":
|
||||
assert arg.strings, "Expected non empty inner_blobs argument"
|
||||
args["inner_blobs"] = arg.strings
|
||||
elif arg.name == "outer_blobs_idx":
|
||||
assert arg.ints, "Expected non empty outer_blobs_idx argument"
|
||||
args["outer_blobs_idx"] = arg.ints
|
||||
return args
|
||||
|
||||
|
||||
def gen_if_gradient(op, g_output):
    """
    Generate the gradient If operator for a forward If op.

    Args:
        op: the forward If operator definition; its first input is the
            condition blob.
        g_output: gradient blobs, one per output of ``op``.

    Returns:
        A pair ``(grad_ops, g_input)``: the Copy ops produced while
        de-duplicating ``g_output`` followed by the gradient If op, and a
        list with the gradient blob name (or None) for every input of
        ``op``.

    Raises:
        RuntimeError: if the then and else branches report different
            gradient blobs for the same blob and neither of them is the
            initial gradient blob.
    """
    from caffe2.python.core import BlobReference
    assert op.type == "If", "Expected If op"
    # first input is the condition blob
    assert len(op.input) > 0, "Expected at least one input in If op"

    assert len(op.output) == len(g_output), \
        "Different number of gradient blobs and If op outputs"

    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output

    init_grad_map = {}  # map from if's output blob to output gradient blob
    op_input = [str(i) for i in op.input]
    op_output = [str(o) for o in op.output]
    for output_name, grad_output_name in zip(op_output, g_output):
        if grad_output_name:
            init_grad_map[BlobReference(output_name)] = \
                BlobReference(grad_output_name)
    # shouldn't call without at least one output gradient available
    assert len(init_grad_map) > 0, "Empty initial gradient map for If op"

    grad_map = {}  # map from blob to gradient blob
    then_net = _get_net_argument(op, "then_net")
    assert then_net, "Expected then subnet in If op"
    then_grad_net, then_grad_map, then_input_names, then_output_names = \
        _gen_subnet_gradient(then_net, init_grad_map)
    assert then_grad_net, "Failed to get gradient net for then in If op"
    grad_map.update(then_grad_map)

    else_input_names = set()
    else_output_names = set()
    else_grad_map = {}
    else_grad_net = None
    else_net = _get_net_argument(op, "else_net")
    if else_net:
        else_grad_net, else_grad_map, else_input_names, else_output_names = \
            _gen_subnet_gradient(else_net, init_grad_map)
        assert else_grad_net, "Failed to get gradient net for else in If op"
        # consider case: else doesn't update blob's gradient and keeps
        # original from init_grad_map, but then updates the gradient
        for else_blob, else_grad_blob in else_grad_map.items():
            if else_blob not in then_grad_map:
                grad_map[else_blob] = else_grad_blob
                continue
            then_grad_blob = then_grad_map[else_blob]
            # if both then and else branches have grad blob name for the
            # same blob and grad names are different, then one of the
            # branches doesn't use blob and has original grad blob name in
            # its grad map, and another branch uses blob and has
            # <blob_name>_grad name in its grad map (might be different
            # from original grad blob)
            if then_grad_blob == else_grad_blob:
                continue
            init_grad_name = init_grad_map[else_blob] \
                if else_blob in init_grad_map else None

            if then_grad_blob == init_grad_name:
                grad_map[else_blob] = else_grad_blob
            elif else_grad_blob == init_grad_name:
                grad_map[else_blob] = then_grad_blob
            else:
                # Raising a bare string is itself a TypeError in Python 3
                # and would hide this message; raise a real exception.
                raise RuntimeError(
                    "Unexpected grad blob name {}, {}, {}".format(
                        else_blob, else_grad_blob, then_grad_blob))

    # make sure gradients of blobs that were not computed
    # by the selected if's branch are initialized with zeros
    common_output_names = then_output_names & else_output_names

    then_grad_blob_names = set(then_grad_map.values())
    then_other_grad_output_names = set(
        o for o in then_output_names - common_output_names
        if o in then_grad_blob_names)
    zero_then = _gen_grad_zero_init_ops(
        init_grad_map, then_grad_map, then_other_grad_output_names)
    if else_grad_net:
        else_grad_net.op.extend(zero_then)
    elif len(zero_then) > 0:
        # there is no else branch: synthesize one holding only the
        # zero-initialization ops
        else_grad_net = caffe2_pb2.NetDef()
        else_grad_net.CopyFrom(then_grad_net)
        if else_grad_net.name:
            else_grad_net.name += "_auto_else_zero_blobs_"
        del else_grad_net.op[:]
        else_grad_net.op.extend(zero_then)
        del else_grad_net.external_input[:]
        del else_grad_net.external_output[:]

    else_grad_blob_names = set(else_grad_map.values())
    else_other_grad_output_names = set(
        o for o in else_output_names - common_output_names
        if o in else_grad_blob_names)
    zero_else = _gen_grad_zero_init_ops(
        init_grad_map, else_grad_map, else_other_grad_output_names)
    then_grad_net.op.extend(zero_else)

    output_names = list(then_output_names | else_output_names)
    input_names = then_input_names | else_input_names
    # make sure condition blob is the first in the list; it is removed from
    # the rest by name ({name}, not set(name), which would split the name
    # into single characters and leave the blob listed twice)
    input_names = [op_input[0]] + list(input_names - {op_input[0]})
    gradient_if_def = _prepare_gradient_if_op(
        fwd_op=op,
        input_names=input_names,
        output_names=output_names,
        then_grad_net=then_grad_net,
        else_grad_net=else_grad_net)
    g_input = [grad_map.get(i, None) for i in op_input]
    return grad_ops + [gradient_if_def], g_input
|
||||
|
||||
|
||||
def _gen_subnet_gradient(subnet, init_grad):
    """
    Build a gradient net for ``subnet`` and list the blobs it reads/writes.

    Args:
        subnet: NetDef of the forward subnet.
        init_grad: initial gradient map handed to the backward pass.

    Returns:
        A tuple ``(gradient_net_def, grad_names_map, input_names,
        output_names)``: a copy of ``subnet`` whose ops are the gradient ops
        and whose external inputs/outputs are cleared, the blob-name to
        gradient-name map, the set of gradient-op inputs that no earlier
        gradient op wrote, and the set of all gradient-op outputs.
    """
    grad_ops, grad_names_map = _gen_subgradient_pass(subnet, init_grad)

    written = set()
    read_from_outside = set()
    for grad_op in grad_ops:
        # an input counts as external unless a previous op produced it
        read_from_outside.update(
            name for name in map(str, grad_op.input) if name not in written)
        written.update(map(str, grad_op.output))

    grad_net = caffe2_pb2.NetDef()
    grad_net.CopyFrom(subnet)
    if grad_net.name:
        grad_net.name += "_grad"
    del grad_net.external_input[:]
    del grad_net.external_output[:]
    del grad_net.op[:]
    grad_net.op.extend(grad_ops)

    return grad_net, grad_names_map, read_from_outside, written
|
||||
|
||||
|
||||
def _get_net_argument(op, net_name):
|
||||
for arg in op.arg:
|
||||
if arg.name and arg.name == net_name:
|
||||
assert arg.n, "Expected non empty net argument " + net_name
|
||||
return arg.n
|
||||
return None
|
||||
|
||||
|
||||
def getNetArgument(op, net_name):
    """Public entry point that forwards to _get_net_argument()."""
    net = _get_net_argument(op, net_name)
    return net
|
||||
|
||||
|
||||
def _gen_subgradient_pass(subnet, init_grad):
    """
    Run backward pass generation over the ops of ``subnet``.

    Args:
        subnet: net definition whose ``op`` list is differentiated.
        init_grad: initial gradient map passed to IR.GetBackwardPass.

    Returns:
        A pair ``(grad_ops, grad_names_map)``: the gradient operators and
        the blob-to-gradient map with keys and values converted to str.
    """
    from caffe2.python.core import IR
    grad_ops, grad_blob_map = IR(subnet.op).GetBackwardPass(init_grad)
    grad_names_map = {
        str(blob): str(grad) for blob, grad in grad_blob_map.items()}
    return grad_ops, grad_names_map
|
||||
|
||||
|
||||
def _do_op_sanity_check_and_process(op):
    """Validate a Do operator and extract its inner/outer blob bindings.

    The checks run in a fixed order and each failure raises AssertionError
    with its own message: operator type, presence of the "net" argument,
    presence and shape of "inner_blobs" / "outer_blobs_idx", matching
    trailing workspace blob in inputs and outputs, absence of duplicates,
    index bounds, and that every outer blob is bound exactly once.

    Returns:
        (subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name)
        where subnet is the "net" argument, the two dicts map outer blob
        names to inner blob names and back, and workspace_blob_name is the
        last output of the operator.
    """
    assert op.type == "Do", "Expected Do op"

    subnet = _get_net_argument(op, "net")
    assert subnet, "No net argument found in Do op"

    inner_names = None
    outer_indices = None
    for arg in op.arg:
        if arg.name == "inner_blobs":
            assert not inner_names, "inner_blobs redefinition"
            assert arg.strings and len(arg.strings) > 0, \
                "Empty inner_blobs argument in Do op"
            inner_names = [raw.decode('utf-8') for raw in arg.strings]
        if arg.name == "outer_blobs_idx":
            assert not outer_indices, "outer_blobs_idx redefinition"
            assert arg.ints and len(arg.ints) > 0, \
                "Empty outer_blobs_idx argument in Do op"
            outer_indices = arg.ints
        # Stop scanning as soon as both arguments have been collected.
        if inner_names and outer_indices:
            break

    assert inner_names, "No inner_blobs argument found in Do op"
    assert outer_indices, "No outer_blobs_idx argument found in Do op"

    assert len(inner_names) == len(outer_indices), \
        "Arguments inner_blobs and outer_blobs_idx of different length in Do op"

    assert len(set(inner_names)) == len(inner_names), \
        "Found duplicates in inner_blobs in Do op"

    input_names = [str(blob) for blob in op.input]
    assert len(input_names) > 0, "Expected at least one input blob"
    # The last input holds the pointer to the workspace; split it off.
    input_workspace_name = input_names[-1]
    input_names = input_names[:-1]

    output_names = [str(blob) for blob in op.output]
    assert len(output_names) > 0, "Expected at least one output blob"
    # The last output holds the pointer to the workspace; split it off.
    workspace_blob_name = output_names[-1]
    assert input_workspace_name == workspace_blob_name, \
        "Expected same input/output workspace blob"
    output_names = output_names[:-1]

    assert len(set(input_names)) == len(input_names), \
        "Found duplicates in Do op inputs"
    assert len(set(output_names)) == len(output_names), \
        "Found duplicates in Do op outputs"

    # outer_blobs_idx indexes into inputs followed by outputs.
    ordered_outer_names = input_names + output_names
    distinct_outer_names = set(ordered_outer_names)
    bound_outer_names = set()
    outer_to_inner_map = {}
    inner_to_outer_map = {}
    for inner_name, outer_idx in zip(inner_names, outer_indices):
        assert 0 <= outer_idx < len(ordered_outer_names), \
            "Outer blob index is out of bounds in Do op"
        outer_name = ordered_outer_names[outer_idx]
        assert outer_name not in bound_outer_names, \
            "Reusage of outer blob name " + outer_name + " in Do op"
        bound_outer_names.add(outer_name)
        outer_to_inner_map[outer_name] = inner_name
        inner_to_outer_map[inner_name] = outer_name

    assert len(bound_outer_names) == len(distinct_outer_names), \
        "Not all outer blob names are used in blob bindings in Do op"

    return subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name
|
||||
|
||||
|
||||
def _prepare_blob_copy_op(from_name, to_name):
    """Return a "Copy" OperatorDef with input `from_name` and output `to_name`."""
    copy_op = caffe2_pb2.OperatorDef()
    copy_op.type = "Copy"
    copy_op.input.append(from_name)
    copy_op.output.append(to_name)
    return copy_op
|
||||
|
||||
|
||||
def _prepare_gradient_do_op(
        fwd_op, fwd_net, grad_ops, inputs, outputs, blob_bindings, saved_fwd_blobs,
        workspace_blob_name):
    """Assemble the gradient counterpart of a forward Do operator.

    Args:
        fwd_op: forward Do OperatorDef, used as the template that is copied.
        fwd_net: forward NetDef, used as the template for the gradient net.
        grad_ops: backward operators placed into the gradient net.
        inputs: list of outer input blob names of the gradient op.
        outputs: list of outer output blob names of the gradient op.
        blob_bindings: mapping of inner blob name -> outer blob name; every
            outer name must appear in inputs + outputs.
        saved_fwd_blobs: names stored in the "saved_fwd_blobs" argument.
        workspace_blob_name: blob appended as the last input and last output.

    Returns:
        The gradient OperatorDef, flagged with is_gradient_op = True.
    """
    grad_net = caffe2_pb2.NetDef()
    grad_net.CopyFrom(fwd_net)
    if grad_net.name:
        grad_net.name += "_grad"
    del grad_net.op[:]
    grad_net.op.extend(grad_ops)
    del grad_net.external_input[:]
    del grad_net.external_output[:]

    grad_do = caffe2_pb2.OperatorDef()
    grad_do.CopyFrom(fwd_op)
    if grad_do.name:
        grad_do.name += "_grad"

    # Replace the blob lists; the workspace pointer blob always comes last.
    del grad_do.input[:]
    grad_do.input.extend(inputs)
    grad_do.input.append(workspace_blob_name)
    del grad_do.output[:]
    grad_do.output.extend(outputs)
    grad_do.output.append(workspace_blob_name)

    net_arg = caffe2_pb2.Argument()
    net_arg.name = "net"
    net_arg.n.CopyFrom(grad_net)

    # Outer indices refer to positions in inputs followed by outputs.
    outer_names_in_order = inputs + outputs
    inner_names = list(blob_bindings)
    outer_indices = [
        outer_names_in_order.index(blob_bindings[inner])
        for inner in inner_names]

    inner_blobs_arg = caffe2_pb2.Argument()
    inner_blobs_arg.name = "inner_blobs"
    inner_blobs_arg.strings.extend(
        [inner.encode('utf-8') for inner in inner_names])

    outer_blobs_idx_arg = caffe2_pb2.Argument()
    outer_blobs_idx_arg.name = "outer_blobs_idx"
    outer_blobs_idx_arg.ints.extend(outer_indices)

    saved_blobs_arg = caffe2_pb2.Argument()
    saved_blobs_arg.name = "saved_fwd_blobs"
    saved_blobs_arg.strings.extend(
        [saved.encode('utf-8') for saved in saved_fwd_blobs])

    # Drop the arguments and control inputs copied from the forward op.
    del grad_do.arg[:]
    grad_do.arg.extend(
        [net_arg, inner_blobs_arg, outer_blobs_idx_arg, saved_blobs_arg])
    del grad_do.control_input[:]

    grad_do.is_gradient_op = True

    return grad_do
|
||||
|
||||
|
||||
def _gen_grad_zero_init_ops(init_grad_map, grad_map, grad_output_names):
    """Create the operators that initialise gradient output blobs.

    Args:
        init_grad_map: mapping of blob name -> already available gradient blob.
        grad_map: mapping of blob name -> gradient blob name.
        grad_output_names: gradient blobs that need an initial value.

    Returns:
        A list of OperatorDefs, in the order of grad_output_names. For each
        gradient output: a "Copy" from the existing gradient when one exists
        under a different name, nothing when it exists under the same name,
        and otherwise a "ConstantFill" with value 0.0 that takes the forward
        blob as input.

    Raises:
        AssertionError: if a gradient output has no entry in grad_map.
    """
    init_ops = []
    for grad_output in grad_output_names:
        # Find the forward blob this gradient belongs to; ConstantFill takes
        # it as input so that grad_output has the same shape.
        output_name = next(
            (blob for blob, grad in grad_map.items() if grad == grad_output),
            None)
        assert output_name, "Unknown gradient output " + grad_output

        init_op = None
        if output_name in init_grad_map:
            # An existing gradient must not be overwritten with zeros; it is
            # copied only when it lives under a different blob name.
            existing_grad = init_grad_map[output_name]
            if existing_grad != grad_output:
                init_op = caffe2_pb2.OperatorDef()
                init_op.type = "Copy"
                init_op.input.extend([str(existing_grad)])
                init_op.output.extend([str(grad_output)])
        else:
            init_op = caffe2_pb2.OperatorDef()
            init_op.type = "ConstantFill"
            init_op.input.extend([output_name])
            init_op.output.extend([grad_output])
            zero_value = caffe2_pb2.Argument()
            zero_value.name = "value"
            zero_value.f = 0.0
            init_op.arg.extend([zero_value])

        if init_op:
            init_ops.append(init_op)
    return init_ops
|
||||
|
||||
|
||||
def _prepare_gradient_if_op(
        fwd_op, input_names, output_names, then_grad_net, else_grad_net):
    """Assemble the gradient counterpart of a forward If operator.

    Args:
        fwd_op: forward If OperatorDef, used as the template that is copied.
        input_names: input blob names of the gradient op.
        output_names: output blob names of the gradient op.
        then_grad_net: gradient net stored as the "then_net" argument.
        else_grad_net: gradient net stored as the "else_net" argument; the
            argument is omitted when this value is falsy.

    Returns:
        The gradient OperatorDef, flagged with is_gradient_op = True.
    """
    grad_if = caffe2_pb2.OperatorDef()
    grad_if.CopyFrom(fwd_op)

    del grad_if.input[:]
    grad_if.input.extend(input_names)
    del grad_if.output[:]
    grad_if.output.extend(output_names)

    then_arg = caffe2_pb2.Argument()
    then_arg.name = "then_net"
    then_arg.n.CopyFrom(then_grad_net)
    branch_args = [then_arg]

    if else_grad_net:
        else_arg = caffe2_pb2.Argument()
        else_arg.name = "else_net"
        else_arg.n.CopyFrom(else_grad_net)
        branch_args.append(else_arg)

    # Drop the arguments and control inputs copied from the forward op.
    del grad_if.arg[:]
    grad_if.arg.extend(branch_args)
    if grad_if.name:
        grad_if.name += "_grad"
    del grad_if.control_input[:]
    grad_if.is_gradient_op = True
    return grad_if
|
||||
|
||||
|
||||
def disambiguate_grad_if_op_output(grad_op, idx, new_grad_output):
    """Rename output `idx` of a gradient If op, including inside its subnets.

    Every operator output in the "then_net" argument, and in the "else_net"
    argument when present, that equals the current name of grad_op.output[idx]
    is replaced by new_grad_output; finally grad_op.output[idx] itself is
    replaced. The operator is modified in place and nothing is returned.
    """
    then_net = _get_net_argument(grad_op, "then_net")
    stale_name = grad_op.output[idx]

    def rename_outputs(branch_net):
        # Rewrite every occurrence of the old gradient name in this branch.
        for branch_op in branch_net.op:
            for position, produced in enumerate(branch_op.output):
                if produced == stale_name:
                    branch_op.output[position] = new_grad_output

    rename_outputs(then_net)

    else_net = _get_net_argument(grad_op, "else_net")
    if else_net:
        rename_outputs(else_net)

    grad_op.output[idx] = new_grad_output
|
||||
49
caffe2/python/control_ops_grad_test.py
Normal file
49
caffe2/python/control_ops_grad_test.py
Normal file
@ -0,0 +1,49 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import unittest
|
||||
from caffe2.python import core, test_util, workspace
|
||||
from caffe2.python.control_ops_grad import disambiguate_grad_if_op_output
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestControl(test_util.TestCase):
    # Tests for the gradient helpers in caffe2.python.control_ops_grad.

    def test_disambiguate_grad_if_op_output(self):
        # Renaming a gradient If op output must rename it in both branch
        # nets and leave every other output untouched.
        workspace.FeedBlob("cond", np.array(True))
        workspace.FeedBlob("then_grad", np.array(1))
        workspace.FeedBlob("else_grad", np.array(2))

        then_model = ModelHelper(name="then_test_model")
        then_model.net.Copy("then_grad", "input_grad")

        else_model = ModelHelper(name="else_test_model")
        else_model.net.Copy("else_grad", "else_temp_grad")
        else_model.net.Copy("else_temp", "input_grad")

        # For BuildGradientGenerators the forward pass needs else_temp as one
        # of its outputs, which later results in a grad op like this one:
        grad_op = core.CreateOperator(
            "If",
            ["cond", "then_grad", "else_grad"],
            ["input_grad", "else_temp_grad"],
            then_net=then_model.net.Proto(),
            else_net=else_model.net.Proto(),
        )

        # In certain cases another branch of the net also generates
        # input_grad and core.py calls _DisambiguateGradOpOutput.
        new_grad_output = "input_grad" + "_autosplit_" + "0"
        disambiguate_grad_if_op_output(grad_op, 0, new_grad_output)

        self.assertEqual(grad_op.output[0], new_grad_output)
        # The output that was not selected must keep its name.
        self.assertEqual(grad_op.output[1], "else_temp_grad")

        seen_net_args = []
        for arg in grad_op.arg:
            seen_net_args.append(arg.name)
            if arg.name == "else_net":
                self.assertEqual(arg.n.op[1].output[0], new_grad_output)
                # Unrelated outputs inside the branch stay as they were.
                self.assertEqual(arg.n.op[0].output[0], "else_temp_grad")
            else:
                self.assertEqual(arg.name, "then_net")
                # The then branch writes the renamed blob as well.
                self.assertEqual(arg.n.op[0].output[0], new_grad_output)

        # Guard against the loop passing vacuously when a net is missing.
        self.assertEqual(sorted(seen_net_args), ["else_net", "then_net"])
|
||||
|
||||
|
||||
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
||||
263
caffe2/python/control_ops_util.py
Normal file
263
caffe2/python/control_ops_util.py
Normal file
@ -0,0 +1,263 @@
|
||||
## @package control_ops_util
|
||||
# Module caffe2.python.control_ops_util
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import core
|
||||
|
||||
|
||||
def get_external_blob_names(net, lexical_scope):
    """
    Determine which outer blobs a net reads and which ones it writes.

    Inputs:
      net - net to return input/output blobs for;
      lexical_scope - all external blob names visible to the net

    Returns a pair (input_names, output_names): the blobs reported as
    undefined by core.get_undefined_blobs for the net's SSA form, and the
    set of operator outputs whose names are in lexical_scope. Asserts that
    every input name is present in lexical_scope.
    """
    # Use the blobs that are actually read/written as external inputs/outputs.
    net_proto = net.Proto()
    net_ssa, _ = core.get_ssa(net_proto)
    input_names = core.get_undefined_blobs(net_ssa)
    for input_name in input_names:
        assert str(input_name) in lexical_scope, \
            "Input blob " + input_name + " is undefined"

    output_names = {
        produced
        for net_op in net_proto.op
        for produced in net_op.output
        if produced in lexical_scope
    }

    return input_names, output_names
|
||||
|
||||
|
||||
def add_if_op(if_net, cond_blob, lexical_scope, then_net, else_net=None):
    """
    Add an 'If' operator to if_net that runs then_net or else_net.

    Blobs used by the then/else subnets are classified by lexical scope, the
    set of outer blob names visible to the 'If' operator. A blob whose name
    is in the lexical scope, or that is first used as an operator input, is
    an outer blob: it must exist in the outer workspace, the subnets can read
    it, and values they write to it are visible outside of the 'If'
    operator. All other blobs are local to the branch's inner workspace.

    Each branch is wrapped in a 'Do' operator with its own workspace pointer
    blob, and the wrapped nets become the then_net/else_net arguments.

    Inputs:
      if_net - net to add an If op to;
      cond_blob - scalar bool blob reference, used as If condition;
      lexical_scope - a set of outer blob names visible to then/else branches;
      then_net/else_net - nets (core.Net) for then/else branches
    """
    def to_refs(names):
        return [core.BlobReference(name=blob, net=None) for blob in names]

    def wrap_branch(branch_net, wrapper_name, workspace_suffix,
                    branch_inputs, branch_outputs):
        # Wrap one branch in a 'Do' op; returns (wrapper proto, workspace blob).
        wrapper_net = core.Net(wrapper_name)

        in_refs = to_refs(branch_inputs)
        out_refs = to_refs(branch_outputs)
        names_in_order = [str(ref) for ref in (in_refs + out_refs)]

        bound_names = list(branch_inputs | branch_outputs)
        bound_indices = [names_in_order.index(blob) for blob in bound_names]

        # Use the net's name so the blob is unique across multiple subnets.
        workspace_blob = if_net.NextScopedBlob(
            if_net.Name() + workspace_suffix)
        in_refs.append(workspace_blob)
        out_refs.append(workspace_blob)

        wrapper_net.Do(
            in_refs,
            out_refs,
            net=branch_net.Proto(),
            inner_blobs=bound_names,
            outer_blobs_idx=bound_indices)
        wrapper_net.AddExternalOutput(*out_refs)
        return wrapper_net.Proto(), workspace_blob

    then_in_names, then_out_names = get_external_blob_names(
        then_net, lexical_scope)

    else_in_names = set()
    else_out_names = set()
    if else_net:
        else_in_names, else_out_names = get_external_blob_names(
            else_net, lexical_scope)

    if_inputs = [cond_blob] + to_refs(then_in_names | else_in_names)
    if_outputs = to_refs(then_out_names | else_out_names)

    if_args = {}
    if_args['then_net'], then_workspace_blob = wrap_branch(
        then_net, 'do_then_net', '/workspace_if_then',
        then_in_names, then_out_names)
    # The workspace pointer blobs must also be inputs/outputs of the If op.
    if_inputs.append(then_workspace_blob)
    if_outputs.append(then_workspace_blob)

    else_workspace_blob = None
    if else_net:
        if_args['else_net'], else_workspace_blob = wrap_branch(
            else_net, 'do_else_net', '/workspace_if_else',
            else_in_names, else_out_names)
        if_inputs.append(else_workspace_blob)
        if_outputs.append(else_workspace_blob)

    if_net.CreateScope([], [then_workspace_blob])
    if else_workspace_blob:
        if_net.CreateScope([], [else_workspace_blob])
    if_net.If(if_inputs, if_outputs, **if_args)
    if_net.AddExternalOutput(*if_outputs)
|
||||
|
||||
|
||||
def add_while_op(
        while_net, cond_blob, lexical_scope, loop_body_net, condition_body_net=None):
    """
    A helper function to add a While op to the net. Same rules for determining
    outer and inner blobs as for the 'If' operator apply for the 'While' operator
    loop and condition subnets. If specified, condition net is executed in a separate
    workspace before the first and after each iteration, the last operator must have
    a single scalar boolean output that is written into the condition blob.
    Inputs:
      while_net - net to add a While op to;
      cond_blob - scalar bool blob reference, used as a stop condition;
      lexical_scope - a set of outer blob names visible to the loop's body;
      loop_body_net - net to execute on each iteration;
      condition_body_net - net to compute condition value

    The loop body, and the condition net when given, are each wrapped in a
    'Do' operator with their own workspace pointer blob; the wrapped nets are
    passed to the While op as the loop_net and cond_net arguments.
    """
    input_blob_names, output_blob_names = get_external_blob_names(
        loop_body_net, lexical_scope)

    # Since it's possible that loop is not going to run even once
    # we have to add loop's external outputs into inputs
    input_blob_names |= output_blob_names

    loop_inputs = [core.BlobReference(name=b, net=None) for b in input_blob_names]
    loop_outputs = [core.BlobReference(name=b, net=None) for b in output_blob_names]

    # The While op reads the condition blob in addition to the loop inputs.
    while_inputs = [cond_blob] + loop_inputs
    while_outputs = [] + loop_outputs

    do_loop_body_net = core.Net('do_loop_body_net')

    # Positions in inputs followed by outputs are what outer_blobs_idx refers to.
    loop_input_output_names_ordered = [
        str(b) for b in (loop_inputs + loop_outputs)]
    loop_body_outer_blob_names = list(input_blob_names | output_blob_names)
    loop_body_outer_blob_names_idx = [
        loop_input_output_names_ordered.index(b) for b in loop_body_outer_blob_names]

    do_loop_body_workspace_blob = \
        while_net.NextScopedBlob(while_net.Name() + '/workspace_loop_body')

    loop_inputs.append(do_loop_body_workspace_blob)
    loop_outputs.append(do_loop_body_workspace_blob)
    # make sure that added workspace pointer blobs are in While inputs/outputs
    while_inputs.append(do_loop_body_workspace_blob)
    while_outputs.append(do_loop_body_workspace_blob)

    do_loop_body_net.Do(
        loop_inputs,
        loop_outputs,
        net=loop_body_net.Proto(),
        inner_blobs=loop_body_outer_blob_names,
        outer_blobs_idx=loop_body_outer_blob_names_idx,
        copy_external_blobs=True)
    do_loop_body_net.AddExternalOutput(*loop_outputs)

    while_args = {}
    while_args['loop_net'] = do_loop_body_net.Proto()

    cond_workspace_blob = None
    if condition_body_net:
        cond_input_blob_names, cond_output_blob_names = get_external_blob_names(
            condition_body_net, lexical_scope)

        # make sure condition blob is written by condition net and is
        # visible outside of it
        found_condition_output = False
        for op in condition_body_net.Proto().op:
            if str(cond_blob) in op.output:
                found_condition_output = True
                break
        assert found_condition_output, \
            "Condition net does not write into condition blob"
        if str(cond_blob) not in cond_output_blob_names:
            cond_output_blob_names.add(str(cond_blob))

        cond_inputs = [core.BlobReference(name=b, net=None)
                       for b in cond_input_blob_names]
        assert str(cond_blob) in cond_output_blob_names, \
            'Condition blob expected in condition net output'
        cond_outputs = [core.BlobReference(name=b, net=None)
                        for b in cond_output_blob_names]

        condition_net = core.Net('do_loop_condition_net')

        cond_input_output_names_ordered = [
            str(b) for b in (cond_inputs + cond_outputs)]
        cond_body_outer_blob_names = \
            list(cond_input_blob_names | cond_output_blob_names)
        cond_body_outer_blob_names_idx = [
            cond_input_output_names_ordered.index(b)
            for b in cond_body_outer_blob_names]

        cond_workspace_blob = \
            while_net.NextScopedBlob(while_net.Name() + '/workspace_loop_cond')
        cond_inputs.append(cond_workspace_blob)
        cond_outputs.append(cond_workspace_blob)

        condition_net.Do(
            cond_inputs,
            cond_outputs,
            net=condition_body_net.Proto(),
            inner_blobs=cond_body_outer_blob_names,
            outer_blobs_idx=cond_body_outer_blob_names_idx)
        condition_net.AddExternalOutput(*cond_outputs)

        while_args['cond_net'] = condition_net.Proto()

        # Expose the condition net's blobs on the While op, skipping names
        # the loop body already contributed.
        while_inputs += [b for b in cond_inputs
                         if str(b) not in input_blob_names]
        while_outputs += [b for b in cond_outputs
                          if str(b) not in output_blob_names]

        # NOTE(review): the nesting of this guard was reconstructed from
        # whitespace-stripped text; confirm it belongs inside the
        # condition-net branch. It creates the condition blob, filled with
        # False, when its name is not in the lexical scope.
        if str(cond_blob) not in lexical_scope:
            while_net.ConstantFill(
                [],
                cond_blob,
                dtype=core.DataType.BOOL,
                value=False)

    while_net.CreateScope([], [do_loop_body_workspace_blob])
    if cond_workspace_blob:
        while_net.CreateScope([], [cond_workspace_blob])
    while_net.While(while_inputs, while_outputs, **while_args)
    while_net.AddExternalOutput(*while_outputs)
|
||||
331
caffe2/python/control_test.py
Normal file
331
caffe2/python/control_test.py
Normal file
@ -0,0 +1,331 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import control, core, test_util, workspace
|
||||
|
||||
import logging
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestControl(test_util.TestCase):
|
||||
def setUp(self):
    """Create the nets shared by the tests: init, counters and conditions."""
    super().setUp()
    self.N_ = 10

    # init-net creates the counter and the two INT64 scalar constants.
    self.init_net_ = core.Net("init-net")
    counter = self.init_net_.CreateCounter([], init_count=0)
    limit_blob = self.init_net_.ConstantFill(
        [], shape=[], value=self.N_, dtype=core.DataType.INT64)
    zero_blob = self.init_net_.ConstantFill(
        [], shape=[], value=0, dtype=core.DataType.INT64)

    # cnt-net counts up once per run and exposes the retrieved count.
    self.cnt_net_ = core.Net("cnt-net")
    self.cnt_net_.CountUp([counter])
    count_blob = self.cnt_net_.RetrieveCount([counter])
    # init-net zero-fills the count blob as part of initialisation.
    self.init_net_.ConstantFill(
        [], [count_blob], shape=[], value=0, dtype=core.DataType.INT64)
    self.cnt_net_.AddExternalOutput(count_blob)

    # cnt-2-net counts up twice per run on the same counter.
    self.cnt_2_net_ = core.Net("cnt-2-net")
    self.cnt_2_net_.CountUp([counter])
    self.cnt_2_net_.CountUp([counter])
    count_2_blob = self.cnt_2_net_.RetrieveCount([counter])
    self.init_net_.ConstantFill(
        [], [count_2_blob], shape=[], value=0, dtype=core.DataType.INT64)
    self.cnt_2_net_.AddExternalOutput(count_2_blob)

    # cond-net: count < N
    self.cond_net_ = core.Net("cond-net")
    below_limit = self.cond_net_.LT([count_blob, limit_blob])
    self.cond_net_.AddExternalOutput(below_limit)

    # not-cond-net: count >= N
    self.not_cond_net_ = core.Net("not-cond-net")
    reached_limit = self.not_cond_net_.GE([count_blob, limit_blob])
    self.not_cond_net_.AddExternalOutput(reached_limit)

    # true-cond-net: 0 < N
    self.true_cond_net_ = core.Net("true-cond-net")
    always_true = self.true_cond_net_.LT([zero_blob, limit_blob])
    self.true_cond_net_.AddExternalOutput(always_true)

    # false-cond-net: 0 > N
    self.false_cond_net_ = core.Net("false-cond-net")
    always_false = self.false_cond_net_.GT([zero_blob, limit_blob])
    self.false_cond_net_.AddExternalOutput(always_false)

    # idle-net only fills a constant; it never touches the counter.
    self.idle_net_ = core.Net("idle-net")
    self.idle_net_.ConstantFill(
        [], shape=[], value=0, dtype=core.DataType.INT64)
|
||||
|
||||
def CheckNetOutput(self, nets_and_expects):
    """
    Check that the last external output of each net has the expected value.
    nets_and_expects is a list of tuples (net, expect)
    """
    for checked_net, expected in nets_and_expects:
        last_output_name = checked_net.Proto().external_output[-1]
        fetched = workspace.FetchBlob(last_output_name)
        self.assertEqual(fetched, expected)
|
||||
|
||||
def CheckNetAllOutput(self, net, expects):
    """
    Check every external output of the net against the expected values.
    expects is a list of bools, one per external output, in order.
    """
    self.assertEqual(len(net.Proto().external_output), len(expects))
    # Lengths are equal at this point, so pairing covers every output.
    for output_name, expected in zip(net.Proto().external_output, expects):
        self.assertEqual(workspace.FetchBlob(output_name), expected)
|
||||
|
||||
def BuildAndRunPlan(self, step):
    """Run a plan made of the init step followed by `step`; assert it succeeds."""
    test_plan = core.Plan("test")
    init_step = control.Do('init', self.init_net_)
    for plan_step in (init_step, step):
        test_plan.AddStep(plan_step)
    run_result = workspace.RunPlan(test_plan)
    self.assertEqual(run_result, True)
|
||||
|
||||
def ForLoopTest(self, nets_or_steps):
    """Run `nets_or_steps` in a For loop of N_ iterations; expect count N_."""
    self.BuildAndRunPlan(control.For('myFor', nets_or_steps, self.N_))
    expected_outputs = [(self.cnt_net_, self.N_)]
    self.CheckNetOutput(expected_outputs)
|
||||
|
||||
def testForLoopWithNets(self):
    # A For loop accepts a single net as well as a list of nets.
    for loop_body in (self.cnt_net_, [self.cnt_net_, self.idle_net_]):
        self.ForLoopTest(loop_body)
|
||||
|
||||
def testForLoopWithStep(self):
    # A For loop accepts a step, alone or mixed with a net in a list.
    count_step = control.Do('count', self.cnt_net_)
    for loop_body in (count_step, [count_step, self.idle_net_]):
        self.ForLoopTest(loop_body)
|
||||
|
||||
def WhileLoopTest(self, nets_or_steps):
    """Run `nets_or_steps` in a While loop on cond_net_; expect count N_."""
    self.BuildAndRunPlan(
        control.While('myWhile', self.cond_net_, nets_or_steps))
    expected_outputs = [(self.cnt_net_, self.N_)]
    self.CheckNetOutput(expected_outputs)
|
||||
|
||||
def testWhileLoopWithNet(self):
    # A While loop accepts a single net as well as a list of nets.
    for loop_body in (self.cnt_net_, [self.cnt_net_, self.idle_net_]):
        self.WhileLoopTest(loop_body)
|
||||
|
||||
def testWhileLoopWithStep(self):
    # A While loop accepts a step, alone or mixed with a net in a list.
    count_step = control.Do('count', self.cnt_net_)
    for loop_body in (count_step, [count_step, self.idle_net_]):
        self.WhileLoopTest(loop_body)
|
||||
|
||||
def UntilLoopTest(self, nets_or_steps):
    """Run `nets_or_steps` in an Until loop on not_cond_net_; expect count N_."""
    self.BuildAndRunPlan(
        control.Until('myUntil', self.not_cond_net_, nets_or_steps))
    expected_outputs = [(self.cnt_net_, self.N_)]
    self.CheckNetOutput(expected_outputs)
|
||||
|
||||
def testUntilLoopWithNet(self):
    # An Until loop accepts a single net as well as a list of nets.
    for loop_body in (self.cnt_net_, [self.cnt_net_, self.idle_net_]):
        self.UntilLoopTest(loop_body)
|
||||
|
||||
def testUntilLoopWithStep(self):
    # An Until loop accepts a step, alone or mixed with a net in a list.
    count_step = control.Do('count', self.cnt_net_)
    for loop_body in (count_step, [count_step, self.idle_net_]):
        self.UntilLoopTest(loop_body)
|
||||
|
||||
def DoWhileLoopTest(self, nets_or_steps):
    """Run `nets_or_steps` in a DoWhile loop on cond_net_; expect count N_."""
    self.BuildAndRunPlan(
        control.DoWhile('myDoWhile', self.cond_net_, nets_or_steps))
    expected_outputs = [(self.cnt_net_, self.N_)]
    self.CheckNetOutput(expected_outputs)
|
||||
|
||||
def testDoWhileLoopWithNet(self):
    # A DoWhile loop accepts a single net as well as a list of nets.
    for loop_body in (self.cnt_net_, [self.idle_net_, self.cnt_net_]):
        self.DoWhileLoopTest(loop_body)
|
||||
|
||||
def testDoWhileLoopWithStep(self):
    # A DoWhile loop accepts a step, alone or mixed with a net in a list.
    count_step = control.Do('count', self.cnt_net_)
    for loop_body in (count_step, [self.idle_net_, count_step]):
        self.DoWhileLoopTest(loop_body)
|
||||
|
||||
def DoUntilLoopTest(self, nets_or_steps):
    """Run `nets_or_steps` in a DoUntil loop on not_cond_net_; expect count N_."""
    self.BuildAndRunPlan(
        control.DoUntil('myDoUntil', self.not_cond_net_, nets_or_steps))
    expected_outputs = [(self.cnt_net_, self.N_)]
    self.CheckNetOutput(expected_outputs)
|
||||
|
||||
def testDoUntilLoopWithNet(self):
    # A DoUntil loop accepts a single net as well as a list of nets.
    for loop_body in (self.cnt_net_, [self.cnt_net_, self.idle_net_]):
        self.DoUntilLoopTest(loop_body)
|
||||
|
||||
def testDoUntilLoopWithStep(self):
    # A DoUntil loop accepts a step, alone or mixed with a net in a list.
    count_step = control.Do('count', self.cnt_net_)
    for loop_body in (count_step, [self.idle_net_, count_step]):
        self.DoUntilLoopTest(loop_body)
|
||||
|
||||
def IfCondTest(self, cond_net, expect, cond_on_blob):
    """Run cnt_net_ under control.If and check its resulting count.

    With cond_on_blob the condition net is executed first and If is given
    its last external output blob; otherwise If is given the net itself.
    """
    if not cond_on_blob:
        step = control.If('myIf', cond_net, self.cnt_net_)
    else:
        evaluate_cond = control.Do('count', cond_net)
        guarded = control.If(
            'myIf', cond_net.Proto().external_output[-1], self.cnt_net_)
        step = control.Do('if-all', evaluate_cond, guarded)
    self.BuildAndRunPlan(step)
    self.CheckNetOutput([(self.cnt_net_, expect)])
|
||||
|
||||
def testIfCondTrueOnNet(self):
    # True condition given as a net: the count is expected to be 1.
    self.IfCondTest(
        cond_net=self.true_cond_net_, expect=1, cond_on_blob=False)
|
||||
|
||||
def testIfCondTrueOnBlob(self):
    # True condition given as a blob: the count is expected to be 1.
    self.IfCondTest(
        cond_net=self.true_cond_net_, expect=1, cond_on_blob=True)
|
||||
|
||||
def testIfCondFalseOnNet(self):
    # False condition given as a net: the count is expected to stay 0.
    self.IfCondTest(
        cond_net=self.false_cond_net_, expect=0, cond_on_blob=False)
|
||||
|
||||
def testIfCondFalseOnBlob(self):
    # False condition given as a blob: the count is expected to stay 0.
    self.IfCondTest(
        cond_net=self.false_cond_net_, expect=0, cond_on_blob=True)
|
||||
|
||||
def IfElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
    """Run control.If with cnt_net_ / cnt_2_net_ as then / else branches.

    cond_value selects which net's output is checked against `expect`:
    cnt_net_ when truthy, cnt_2_net_ otherwise. With cond_on_blob the
    condition net is executed first and If is given its last external
    output blob; otherwise If is given the net itself.
    """
    run_net = self.cnt_net_ if cond_value else self.cnt_2_net_
    if not cond_on_blob:
        step = control.If(
            'myIfElse', cond_net, self.cnt_net_, self.cnt_2_net_)
    else:
        evaluate_cond = control.Do('count', cond_net)
        branching = control.If(
            'myIfElse', cond_net.Proto().external_output[-1],
            self.cnt_net_, self.cnt_2_net_)
        step = control.Do('if-else-all', evaluate_cond, branching)
    self.BuildAndRunPlan(step)
    self.CheckNetOutput([(run_net, expect)])
|
||||
|
||||
def testIfElseCondTrueOnNet(self):
    # True condition as a net: cnt_net_ is checked for a count of 1.
    self.IfElseCondTest(
        cond_net=self.true_cond_net_, cond_value=True, expect=1,
        cond_on_blob=False)
|
||||
|
||||
def testIfElseCondTrueOnBlob(self):
    # True condition as a blob: cnt_net_ is checked for a count of 1.
    self.IfElseCondTest(
        cond_net=self.true_cond_net_, cond_value=True, expect=1,
        cond_on_blob=True)
|
||||
|
||||
def testIfElseCondFalseOnNet(self):
    # False condition as a net: cnt_2_net_ is checked for a count of 2.
    self.IfElseCondTest(
        cond_net=self.false_cond_net_, cond_value=False, expect=2,
        cond_on_blob=False)
|
||||
|
||||
def testIfElseCondFalseOnBlob(self):
    # False condition as a blob: cnt_2_net_ is checked for a count of 2.
    self.IfElseCondTest(
        cond_net=self.false_cond_net_, cond_value=False, expect=2,
        cond_on_blob=True)
|
||||
|
||||
def IfNotCondTest(self, cond_net, expect, cond_on_blob):
    """Run cnt_net_ under control.IfNot and check its resulting count.

    With cond_on_blob the condition net is executed first and IfNot is
    given its last external output blob; otherwise IfNot is given the net.
    """
    if not cond_on_blob:
        step = control.IfNot('myIfNot', cond_net, self.cnt_net_)
    else:
        evaluate_cond = control.Do('count', cond_net)
        guarded = control.IfNot(
            'myIfNot', cond_net.Proto().external_output[-1], self.cnt_net_)
        step = control.Do('if-not', evaluate_cond, guarded)
    self.BuildAndRunPlan(step)
    self.CheckNetOutput([(self.cnt_net_, expect)])
|
||||
|
||||
def testIfNotCondTrueOnNet(self):
    # True condition given as a net: the count is expected to stay 0.
    self.IfNotCondTest(
        cond_net=self.true_cond_net_, expect=0, cond_on_blob=False)
|
||||
|
||||
def testIfNotCondTrueOnBlob(self):
    # True condition given as a blob: the count is expected to stay 0.
    self.IfNotCondTest(
        cond_net=self.true_cond_net_, expect=0, cond_on_blob=True)
|
||||
|
||||
def testIfNotCondFalseOnNet(self):
    # False condition given as a net: the count is expected to be 1.
    self.IfNotCondTest(
        cond_net=self.false_cond_net_, expect=1, cond_on_blob=False)
|
||||
|
||||
def testIfNotCondFalseOnBlob(self):
    # False condition given as a blob: the count is expected to be 1.
    self.IfNotCondTest(
        cond_net=self.false_cond_net_, expect=1, cond_on_blob=True)
|
||||
|
||||
def IfNotElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
    """Run control.IfNot with cnt_net_ / cnt_2_net_ as its two branches.

    cond_value selects which net's output is checked against `expect`:
    cnt_2_net_ when truthy, cnt_net_ otherwise. With cond_on_blob the
    condition net is executed first and IfNot is given its last external
    output blob; otherwise IfNot is given the net itself.
    """
    run_net = self.cnt_2_net_ if cond_value else self.cnt_net_
    if not cond_on_blob:
        step = control.IfNot(
            'myIfNotElse', cond_net, self.cnt_net_, self.cnt_2_net_)
    else:
        evaluate_cond = control.Do('count', cond_net)
        branching = control.IfNot(
            'myIfNotElse', cond_net.Proto().external_output[-1],
            self.cnt_net_, self.cnt_2_net_)
        step = control.Do('if-not-else', evaluate_cond, branching)
    self.BuildAndRunPlan(step)
    self.CheckNetOutput([(run_net, expect)])
|
||||
|
||||
def testIfNotElseCondTrueOnNet(self):
    # True condition as a net: cnt_2_net_ is checked for a count of 2.
    self.IfNotElseCondTest(
        cond_net=self.true_cond_net_, cond_value=True, expect=2,
        cond_on_blob=False)
|
||||
|
||||
def testIfNotElseCondTrueOnBlob(self):
    # True condition as a blob: cnt_2_net_ is checked for a count of 2.
    self.IfNotElseCondTest(
        cond_net=self.true_cond_net_, cond_value=True, expect=2,
        cond_on_blob=True)
|
||||
|
||||
def testIfNotElseCondFalseOnNet(self):
    # False condition as a net: cnt_net_ is checked for a count of 1.
    self.IfNotElseCondTest(
        cond_net=self.false_cond_net_, cond_value=False, expect=1,
        cond_on_blob=False)
|
||||
|
||||
def testIfNotElseCondFalseOnBlob(self):
    # False condition as a blob: cnt_net_ is checked for a count of 1.
    self.IfNotElseCondTest(
        cond_net=self.false_cond_net_, cond_value=False, expect=1,
        cond_on_blob=True)
|
||||
|
||||
def testSwitch(self):
    # Switch over a false and a true condition: cnt_net_ is expected to
    # report 0 and cnt_2_net_ is expected to report 2.
    false_case = (self.false_cond_net_, self.cnt_net_)
    true_case = (self.true_cond_net_, self.cnt_2_net_)
    switch_step = control.Switch('mySwitch', false_case, true_case)
    self.BuildAndRunPlan(switch_step)
    self.CheckNetOutput([(self.cnt_net_, 0), (self.cnt_2_net_, 2)])
|
||||
|
||||
def testSwitchNot(self):
    """SwitchNot over (false -> cnt_net_) and (true -> cnt_2_net_) cases.

    After running the plan, cnt_net_ is checked against 1 and cnt_2_net_
    against 0.
    """
    false_case = (self.false_cond_net_, self.cnt_net_)
    true_case = (self.true_cond_net_, self.cnt_2_net_)
    step = control.SwitchNot('mySwitchNot', false_case, true_case)
    self.BuildAndRunPlan(step)

    expected_outputs = [(self.cnt_net_, 1), (self.cnt_2_net_, 0)]
    self.CheckNetOutput(expected_outputs)
|
||||
|
||||
def testBoolNet(self):
    """Exercise control.BoolNet with the three argument shapes used here.

    Cases, in order: a single (name, value) pair, two pairs passed as
    separate positional arguments, and two pairs passed as one list.
    """
    cases = [
        ((('a', True),), [True]),
        ((('a', True), ('b', False)), [True, False]),
        (([('a', True), ('b', False)],), [True, False]),
    ]
    for bool_args, expected in cases:
        bool_net = control.BoolNet(*bool_args)
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, expected)
|
||||
|
||||
def testCombineConditions(self):
    """Combine the true and false condition nets with 'Or', then 'And'.

    For each relation, both condition nets are run ahead of the combined
    net within one step, and the combined net's output is checked
    (True for 'Or', False for 'And').
    """
    for relation, expected in (('Or', True), ('And', False)):
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], relation)
        step = control.Do(
            'combine',
            self.true_cond_net_,
            self.false_cond_net_,
            combine_net,
        )
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, expected)])
|
||||
|
||||
def testMergeConditionNets(self):
    """Merge the true and false condition nets with 'Or', then 'And'.

    Only the merged net is run in each plan; its output is checked
    (True for 'Or', False for 'And').
    """
    for relation, expected in (('Or', True), ('And', False)):
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], relation)
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, expected)])
|
||||
2
caffe2/python/convert.py
Normal file
2
caffe2/python/convert.py
Normal file
@ -0,0 +1,2 @@
|
||||
## @package convert
|
||||
# Module caffe2.python.convert
|
||||
14
caffe2/python/convert_test.py
Normal file
14
caffe2/python/convert_test.py
Normal file
@ -0,0 +1,14 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import workspace
|
||||
import unittest
|
||||
|
||||
class TestOperator(unittest.TestCase):
    """Test case scaffold that resets the Caffe2 workspace before each test."""

    def setUp(self):
        """Reset the workspace ahead of every test method."""
        workspace.ResetWorkspace()


if __name__ == '__main__':
    # Allow running this file directly as a test script.
    unittest.main()
|
||||
727
caffe2/python/convnet_benchmarks.py
Normal file
727
caffe2/python/convnet_benchmarks.py
Normal file
@ -0,0 +1,727 @@
|
||||
## @package convnet_benchmarks
|
||||
# Module caffe2.python.convnet_benchmarks
|
||||
"""
|
||||
Benchmark for common convnets.
|
||||
|
||||
Speed on Titan X, with 10 warmup steps and 10 main steps and with different
|
||||
versions of cudnn, are as follows (time reported below is per-batch time,
|
||||
forward / forward+backward):
|
||||
|
||||
CuDNN V3 CuDNN v4
|
||||
AlexNet 32.5 / 108.0 27.4 / 90.1
|
||||
OverFeat 113.0 / 342.3 91.7 / 276.5
|
||||
Inception 134.5 / 485.8 125.7 / 450.6
|
||||
VGG (batch 64) 200.8 / 650.0 164.1 / 551.7
|
||||
|
||||
Speed on Inception with varied batch sizes and CuDNN v4 is as follows:
|
||||
|
||||
Batch Size Speed per batch Speed per image
|
||||
16 22.8 / 72.7 1.43 / 4.54
|
||||
32 38.0 / 127.5 1.19 / 3.98
|
||||
64 67.2 / 233.6 1.05 / 3.65
|
||||
128 125.7 / 450.6 0.98 / 3.52
|
||||
|
||||
Speed on Tesla M40, with 10 warmup steps and 10 main steps and with cudnn
|
||||
v4, is as follows:
|
||||
|
||||
AlexNet 68.4 / 218.1
|
||||
OverFeat 210.5 / 630.3
|
||||
Inception 300.2 / 1122.2
|
||||
VGG (batch 64) 405.8 / 1327.7
|
||||
|
||||
(Note that these numbers involve a "full" backprop, i.e. the gradient
|
||||
with respect to the input image is also computed.)
|
||||
|
||||
To get the numbers, simply run:
|
||||
|
||||
for MODEL in AlexNet OverFeat Inception; do
|
||||
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
||||
--batch_size 128 --model $MODEL --forward_only
|
||||
done
|
||||
for MODEL in AlexNet OverFeat Inception; do
|
||||
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
||||
--batch_size 128 --model $MODEL
|
||||
done
|
||||
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
||||
--batch_size 64 --model VGGA --forward_only
|
||||
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
||||
--batch_size 64 --model VGGA
|
||||
|
||||
for BS in 16 32 64 128; do
|
||||
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
||||
--batch_size $BS --model Inception --forward_only
|
||||
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
|
||||
--batch_size $BS --model Inception
|
||||
done
|
||||
|
||||
Note that VGG needs to be run at batch 64 due to memory limit on the backward
|
||||
pass.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from caffe2.python import workspace, brew, model_helper
|
||||
|
||||
|
||||
def MLP(order, cudnn_ws):
    """Build a multi-column fully connected benchmark model.

    Three parallel columns of 20 stacked 256-wide FC layers all start from
    the "data" blob; the column tops are summed, projected to 1000 outputs
    and fed to LabelCrossEntropy / AveragedLoss.

    Args:
        order: unused by this generator.
        cudnn_ws: unused by this generator.

    Returns:
        Tuple of (model, d) where d is the FC width (256).
    """
    model = model_helper.ModelHelper(name="MLP")
    d = 256
    depth = 20
    width = 3

    for layer in range(depth):
        for column in range(width):
            # The first layer of every column reads the shared input blob.
            if layer > 0:
                blob_in = "fc_{}_{}".format(layer, column)
            else:
                blob_in = "data"
            blob_out = "fc_{}_{}".format(layer + 1, column)
            brew.fc(
                model,
                blob_in,
                blob_out,
                dim_in=d,
                dim_out=d,
                weight_init=('XavierFill', {}),
                bias_init=('XavierFill', {}),
            )

    column_tops = ["fc_{}_{}".format(depth, column) for column in range(width)]
    brew.sum(model, column_tops, ["sum"])
    brew.fc(
        model,
        "sum",
        "last",
        dim_in=d,
        dim_out=1000,
        weight_init=('XavierFill', {}),
        bias_init=('XavierFill', {}),
    )
    xent = model.net.LabelCrossEntropy(["last", "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, d
|
||||
|
||||
|
||||
def AlexNet(order, cudnn_ws):
    """Build the AlexNet benchmark model.

    Args:
        order: storage order placed in the model's arg scope.
        cudnn_ws: if truthy, stored in the arg scope as 'ws_nbytes_limit'.

    Returns:
        Tuple of (model, 224); the second element is the input image size
        that Benchmark uses to shape the "data" blob.
    """
    scope_args = {
        'order': order,
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
    }
    if cudnn_ws:
        scope_args['ws_nbytes_limit'] = cudnn_ws
    model = model_helper.ModelHelper(name="alexnet", arg_scope=scope_args)

    def conv_relu(blob_in, name, dim_in, dim_out, kernel, **conv_kwargs):
        # Convolution followed by a ReLU written to the blob called `name`.
        conv = brew.conv(
            model, blob_in, name, dim_in, dim_out, kernel,
            ('XavierFill', {}), ('ConstantFill', {}), **conv_kwargs
        )
        return brew.relu(model, conv, name)

    def fully_connected(blob_in, name, dim_in, dim_out):
        return brew.fc(
            model, blob_in, name, dim_in, dim_out,
            ('XavierFill', {}), ('ConstantFill', {})
        )

    relu1 = conv_relu("data", "conv1", 3, 64, 11, stride=4, pad=2)
    pool1 = brew.max_pool(model, relu1, "pool1", kernel=3, stride=2)
    relu2 = conv_relu(pool1, "conv2", 64, 192, 5, pad=2)
    pool2 = brew.max_pool(model, relu2, "pool2", kernel=3, stride=2)
    relu3 = conv_relu(pool2, "conv3", 192, 384, 3, pad=1)
    relu4 = conv_relu(relu3, "conv4", 384, 256, 3, pad=1)
    relu5 = conv_relu(relu4, "conv5", 256, 256, 3, pad=1)
    pool5 = brew.max_pool(model, relu5, "pool5", kernel=3, stride=2)

    fc6 = fully_connected(pool5, "fc6", 256 * 6 * 6, 4096)
    relu6 = brew.relu(model, fc6, "fc6")
    fc7 = fully_connected(relu6, "fc7", 4096, 4096)
    relu7 = brew.relu(model, fc7, "fc7")
    fc8 = fully_connected(relu7, "fc8", 4096, 1000)

    pred = brew.softmax(model, fc8, "pred")
    xent = model.net.LabelCrossEntropy([pred, "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 224
|
||||
|
||||
|
||||
def OverFeat(order, cudnn_ws):
    """Build the OverFeat benchmark model.

    Args:
        order: storage order placed in the model's arg scope.
        cudnn_ws: if truthy, stored in the arg scope as 'ws_nbytes_limit'.

    Returns:
        Tuple of (model, 231); the second element is the input image size
        that Benchmark uses to shape the "data" blob.
    """
    scope_args = {
        'order': order,
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
    }
    if cudnn_ws:
        scope_args['ws_nbytes_limit'] = cudnn_ws
    model = model_helper.ModelHelper(name="overfeat", arg_scope=scope_args)

    def conv_relu(blob_in, name, dim_in, dim_out, kernel, **conv_kwargs):
        # Convolution followed by a ReLU written to the blob called `name`.
        conv = brew.conv(
            model, blob_in, name, dim_in, dim_out, kernel,
            ('XavierFill', {}), ('ConstantFill', {}), **conv_kwargs
        )
        return brew.relu(model, conv, name)

    def fully_connected(blob_in, name, dim_in, dim_out):
        return brew.fc(
            model, blob_in, name, dim_in, dim_out,
            ('XavierFill', {}), ('ConstantFill', {})
        )

    relu1 = conv_relu("data", "conv1", 3, 96, 11, stride=4)
    pool1 = brew.max_pool(model, relu1, "pool1", kernel=2, stride=2)
    relu2 = conv_relu(pool1, "conv2", 96, 256, 5)
    pool2 = brew.max_pool(model, relu2, "pool2", kernel=2, stride=2)
    relu3 = conv_relu(pool2, "conv3", 256, 512, 3, pad=1)
    relu4 = conv_relu(relu3, "conv4", 512, 1024, 3, pad=1)
    relu5 = conv_relu(relu4, "conv5", 1024, 1024, 3, pad=1)
    pool5 = brew.max_pool(model, relu5, "pool5", kernel=2, stride=2)

    fc6 = fully_connected(pool5, "fc6", 1024 * 6 * 6, 3072)
    relu6 = brew.relu(model, fc6, "fc6")
    fc7 = fully_connected(relu6, "fc7", 3072, 4096)
    relu7 = brew.relu(model, fc7, "fc7")
    fc8 = fully_connected(relu7, "fc8", 4096, 1000)

    pred = brew.softmax(model, fc8, "pred")
    xent = model.net.LabelCrossEntropy([pred, "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 231
|
||||
|
||||
|
||||
def VGGA(order, cudnn_ws):
    """Build the VGG-A benchmark model.

    Eight 3x3 pad-1 convolutions (each followed by an in-place ReLU), with
    2x2 stride-2 max pooling after conv1, conv2, conv4, conv6 and conv8,
    then three FC layers and a softmax / cross-entropy / averaged loss.

    Args:
        order: storage order placed in the model's arg scope.
        cudnn_ws: if truthy, stored in the arg scope as 'ws_nbytes_limit'.

    Returns:
        Tuple of (model, 231); the second element is the input image size
        that Benchmark uses to shape the "data" blob.
    """
    scope_args = {
        'order': order,
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
    }
    if cudnn_ws:
        scope_args['ws_nbytes_limit'] = cudnn_ws
    model = model_helper.ModelHelper(name="vgga", arg_scope=scope_args)

    # (conv name, input channels, output channels, pool name or None)
    conv_stack = [
        ("conv1", 3, 64, "pool1"),
        ("conv2", 64, 128, "pool2"),
        ("conv3", 128, 256, None),
        ("conv4", 256, 256, "pool4"),
        ("conv5", 256, 512, None),
        ("conv6", 512, 512, "pool6"),
        ("conv7", 512, 512, None),
        ("conv8", 512, 512, "pool8"),
    ]
    blob = "data"
    for conv_name, dim_in, dim_out, pool_name in conv_stack:
        conv = brew.conv(
            model, blob, conv_name, dim_in, dim_out, 3,
            ('XavierFill', {}), ('ConstantFill', {}),
            pad=1,
        )
        blob = brew.relu(model, conv, conv_name)
        if pool_name is not None:
            blob = brew.max_pool(model, blob, pool_name, kernel=2, stride=2)

    fcix = brew.fc(
        model, blob, "fcix", 512 * 7 * 7, 4096,
        ('XavierFill', {}), ('ConstantFill', {})
    )
    reluix = brew.relu(model, fcix, "fcix")
    fcx = brew.fc(
        model, reluix, "fcx", 4096, 4096,
        ('XavierFill', {}), ('ConstantFill', {})
    )
    relux = brew.relu(model, fcx, "fcx")
    fcxi = brew.fc(
        model, relux, "fcxi", 4096, 1000,
        ('XavierFill', {}), ('ConstantFill', {})
    )

    pred = brew.softmax(model, fcxi, "pred")
    xent = model.net.LabelCrossEntropy([pred, "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 231
|
||||
|
||||
|
||||
def _InceptionModule(
    model, input_blob, input_depth, output_name, conv1_depth, conv3_depths,
    conv5_depths, pool_depth
):
    """Add one four-branch inception module to ``model``.

    Args:
        model: model helper the operators are added to.
        input_blob: blob feeding all four branches.
        input_depth: channel count of ``input_blob``.
        output_name: name of the concatenated output; also the prefix of
            every intermediate blob (``output_name + ":conv1"`` etc.).
        conv1_depth: output channels of the 1x1 branch.
        conv3_depths: indexable; [0] is the 1x1 reduce width, [1] the 3x3
            output width.
        conv5_depths: indexable; [0] is the 1x1 reduce width, [1] the 5x5
            output width.
        pool_depth: output channels of the 1x1 projection after pooling.

    Returns:
        The result of concatenating the four branch outputs.
    """
    def conv_relu(blob_in, suffix, dim_in, dim_out, kernel, **conv_kwargs):
        # Convolution followed by an in-place ReLU on the conv output blob.
        conv = brew.conv(
            model, blob_in, output_name + suffix, dim_in, dim_out, kernel,
            ('XavierFill', {}), ('ConstantFill', {}), **conv_kwargs
        )
        return brew.relu(model, conv, conv)

    # Branch 1: 1x1 convolution.
    branch1 = conv_relu(input_blob, ":conv1", input_depth, conv1_depth, 1)

    # Branch 2: 1x1 reduce, then 3x3 convolution.
    reduce3 = conv_relu(
        input_blob, ":conv3_reduce", input_depth, conv3_depths[0], 1
    )
    branch3 = conv_relu(
        reduce3, ":conv3", conv3_depths[0], conv3_depths[1], 3, pad=1
    )

    # Branch 3: 1x1 reduce, then 5x5 convolution.
    reduce5 = conv_relu(
        input_blob, ":conv5_reduce", input_depth, conv5_depths[0], 1
    )
    branch5 = conv_relu(
        reduce5, ":conv5", conv5_depths[0], conv5_depths[1], 5, pad=2
    )

    # Branch 4: 3x3 max pool, then 1x1 projection.
    pooled = brew.max_pool(
        model, input_blob, output_name + ":pool", kernel=3, stride=1, pad=1
    )
    branch_pool = conv_relu(pooled, ":pool_proj", input_depth, pool_depth, 1)

    return brew.concat(
        model, [branch1, branch3, branch5, branch_pool], output_name
    )
|
||||
|
||||
|
||||
def Inception(order, cudnn_ws):
    """Build the Inception benchmark model.

    Args:
        order: storage order placed in the model's arg scope.
        cudnn_ws: if truthy, stored in the arg scope as 'ws_nbytes_limit'.

    Returns:
        Tuple of (model, 224); the second element is the input image size
        that Benchmark uses to shape the "data" blob.
    """
    scope_args = {
        'order': order,
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
    }
    if cudnn_ws:
        scope_args['ws_nbytes_limit'] = cudnn_ws
    model = model_helper.ModelHelper(name="inception", arg_scope=scope_args)

    # Stem: 7x7 conv, pool, 1x1 conv, 3x3 conv, pool.
    conv1 = brew.conv(
        model, "data", "conv1", 3, 64, 7,
        ('XavierFill', {}), ('ConstantFill', {}),
        stride=2,
        pad=3,
    )
    relu1 = brew.relu(model, conv1, "conv1")
    pool1 = brew.max_pool(model, relu1, "pool1", kernel=3, stride=2, pad=1)
    conv2a = brew.conv(
        model, pool1, "conv2a", 64, 64, 1,
        ('XavierFill', {}), ('ConstantFill', {})
    )
    conv2a = brew.relu(model, conv2a, conv2a)
    conv2 = brew.conv(
        model, conv2a, "conv2", 64, 192, 3,
        ('XavierFill', {}), ('ConstantFill', {}),
        pad=1,
    )
    relu2 = brew.relu(model, conv2, "conv2")
    blob = brew.max_pool(model, relu2, "pool2", kernel=3, stride=2, pad=1)

    # Inception modules, with stride-2 max pools interleaved.
    # ("inc", name, input depth, conv1 depth, conv3 depths, conv5 depths,
    #  pool depth) or ("pool", name).
    stages = [
        ("inc", "inc3", 192, 64, [96, 128], [16, 32], 32),
        ("inc", "inc4", 256, 128, [128, 192], [32, 96], 64),
        ("pool", "pool5"),
        ("inc", "inc5", 480, 192, [96, 208], [16, 48], 64),
        ("inc", "inc6", 512, 160, [112, 224], [24, 64], 64),
        ("inc", "inc7", 512, 128, [128, 256], [24, 64], 64),
        ("inc", "inc8", 512, 112, [144, 288], [32, 64], 64),
        ("inc", "inc9", 528, 256, [160, 320], [32, 128], 128),
        ("pool", "pool9"),
        ("inc", "inc10", 832, 256, [160, 320], [32, 128], 128),
        ("inc", "inc11", 832, 384, [192, 384], [48, 128], 128),
    ]
    for stage in stages:
        if stage[0] == "pool":
            blob = brew.max_pool(
                model, blob, stage[1], kernel=3, stride=2, pad=1
            )
        else:
            _, name, depth_in, depth1, depths3, depths5, depth_pool = stage
            blob = _InceptionModule(
                model, blob, depth_in, name, depth1, depths3, depths5,
                depth_pool
            )

    pool11 = brew.average_pool(model, blob, "pool11", kernel=7, stride=1)
    fc = brew.fc(
        model, pool11, "fc", 1024, 1000,
        ('XavierFill', {}), ('ConstantFill', {})
    )
    # It seems that Soumith's benchmark does not have softmax on top
    # for Inception. We add it anyway so there is a proper backward pass.
    pred = brew.softmax(model, fc, "pred")
    xent = model.net.LabelCrossEntropy([pred, "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 224
|
||||
|
||||
|
||||
def AddParameterUpdate(model):
    """Append a plain SGD update for every parameter of ``model``.

    Not tuned to actually train the models. Adds an iteration counter, a
    step-policy learning rate and one WeightedSum per parameter that
    writes ``param * ONE + grad * LR`` back into the parameter blob.
    """
    iteration = brew.iter(model, "iter")
    learning_rate = model.net.LearningRate(
        iteration,
        "LR",
        base_lr=-1e-8,
        policy="step",
        stepsize=10000,
        gamma=0.999,
    )
    one = model.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
    for param in model.params:
        gradient = model.param_to_grad[param]
        model.net.WeightedSum([param, one, gradient, learning_rate], param)
|
||||
|
||||
|
||||
def Benchmark(model_gen, arg):
    """Build a model with ``model_gen`` and benchmark it.

    Args:
        model_gen: callable taking (order, cudnn_ws) and returning
            (model, input_size).
        arg: parsed command-line namespace; the attributes read here are
            order, cudnn_ws, net_type, num_workers, batch_size, model,
            forward_only, cpu, engine, dump_model, warmup_iterations,
            iterations and layer_wise_benchmark.
    """
    model, input_size = model_gen(arg.order, arg.cudnn_ws)
    net_proto = model.Proto()
    net_proto.type = arg.net_type
    model.Proto().num_workers = arg.num_workers

    # So that everything runs without feeding more blobs, the data and
    # label blobs are created by the parameter initialization net too.
    batch = arg.batch_size
    if arg.model == "MLP":
        input_shape = [batch, input_size]
    elif arg.order == "NCHW":
        input_shape = [batch, 3, input_size, input_size]
    else:
        input_shape = [batch, input_size, input_size, 3]

    init_net = model.param_init_net
    init_net.GaussianFill([], "data", shape=input_shape, mean=0.0, std=1.0)
    init_net.UniformIntFill([], "label", shape=[batch, ], min=0, max=999)

    if not arg.forward_only:
        print('{}: running forward-backward.'.format(arg.model))
        model.AddGradientOperators(["loss"])
        AddParameterUpdate(model)
        if arg.order == 'NHWC':
            print(
                '==WARNING==\n'
                'NHWC order with CuDNN may not be supported yet, so I might\n'
                'exit suddenly.'
            )
    else:
        print('{}: running forward only.'.format(arg.model))

    if not arg.cpu:
        model.param_init_net.RunAllOnGPU()
        model.net.RunAllOnGPU()

    if arg.engine:
        for op in model.net.Proto().op:
            op.engine = arg.engine

    if arg.dump_model:
        # Writes out the pbtxt for benchmarks on e.g. Android
        init_path = "{0}_init_batch_{1}.pbtxt".format(
            arg.model, arg.batch_size
        )
        with open(init_path, "w") as fid:
            fid.write(str(model.param_init_net.Proto()))
        net_path = "{0}.pbtxt".format(arg.model)
        with open(net_path, "w") as fid:
            fid.write(str(model.net.Proto()))

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)
    workspace.BenchmarkNet(
        model.net.Proto().name,
        arg.warmup_iterations,
        arg.iterations,
        arg.layer_wise_benchmark,
    )
|
||||
|
||||
|
||||
def GetArgumentParser():
    """Create the command-line parser for the convnet benchmark script.

    Returns:
        An ``argparse.ArgumentParser`` with the benchmark options
        registered in a fixed order.
    """
    parser = argparse.ArgumentParser(description="Caffe2 benchmark.")

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--batch_size",
         dict(type=int, default=128, help="The batch size.")),
        ("--model",
         dict(type=str, help="The model to benchmark.")),
        ("--order",
         dict(type=str, default="NCHW", help="The order to evaluate.")),
        ("--cudnn_ws",
         dict(type=int, help="The cudnn workspace size.")),
        ("--iterations",
         dict(type=int, default=10,
              help="Number of iterations to run the network.")),
        ("--warmup_iterations",
         dict(type=int, default=10,
              help="Number of warm-up iterations before benchmarking.")),
        ("--forward_only",
         dict(action='store_true',
              help="If set, only run the forward pass.")),
        ("--layer_wise_benchmark",
         dict(action='store_true',
              help="If True, run the layer-wise benchmark as well.")),
        ("--cpu",
         dict(action='store_true',
              help="If True, run testing on CPU instead of GPU.")),
        ("--engine",
         dict(type=str, default="",
              help="If set, blindly prefer the given engine(s) for every op.")),
        ("--dump_model",
         dict(action='store_true',
              help="If True, dump the model prototxts to disk.")),
        ("--net_type", dict(type=str, default="dag")),
        ("--num_workers", dict(type=int, default=2)),
        ("--use-nvtx", dict(default=False, action='store_true')),
        ("--htrace_span_log_path", dict(type=str)),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: parse the command line, initialise Caffe2 and
    # benchmark the requested model. Unrecognised arguments are forwarded
    # to workspace.GlobalInit.
    parser = GetArgumentParser()
    args, extra_args = parser.parse_known_args()

    model_map = {
        'AlexNet': AlexNet,
        'OverFeat': OverFeat,
        'VGGA': VGGA,
        'Inception': Inception,
        'MLP': MLP,
    }

    if not args.batch_size or not args.model or not args.order:
        parser.print_help()
    elif args.model not in model_map:
        # Report a usage error before any initialisation instead of
        # failing later with a bare KeyError on the model lookup.
        parser.error(
            "unknown --model {!r}; choose one of: {}".format(
                args.model, ", ".join(sorted(model_map))
            )
        )
    else:
        global_init_args = ['caffe2', '--caffe2_log_level=0'] + extra_args
        if args.use_nvtx:
            global_init_args.append('--caffe2_use_nvtx')
        if args.htrace_span_log_path:
            global_init_args.append(
                '--caffe2_htrace_span_log_path=' + args.htrace_span_log_path
            )
        workspace.GlobalInit(global_init_args)

        Benchmark(model_map[args.model], args)
|
||||
23
caffe2/python/convnet_benchmarks_test.py
Normal file
23
caffe2/python/convnet_benchmarks_test.py
Normal file
@ -0,0 +1,23 @@
|
||||
import unittest
|
||||
from caffe2.python import convnet_benchmarks as cb
|
||||
from caffe2.python import test_util, workspace
|
||||
|
||||
|
||||
# TODO: investigate why this randomly core dumps in ROCm CI
@unittest.skipIf(not workspace.has_cuda_support, "no cuda gpu")
class TestConvnetBenchmarks(test_util.TestCase):
    """Smoke test that runs each convnet benchmark for a single iteration."""

    def testConvnetBenchmarks(self):
        """Benchmark every convnet forward-backward and forward-only."""
        common_flags = (
            '--batch_size 16 --order NCHW --iterations 1 '
            '--warmup_iterations 1'
        )
        all_args = [common_flags, common_flags + ' --forward_only']
        model_generators = [cb.AlexNet, cb.OverFeat, cb.VGGA, cb.Inception]
        for model_gen in model_generators:
            for flags in all_args:
                parser = cb.GetArgumentParser()
                parsed = parser.parse_args(flags.split(' '))
                cb.Benchmark(model_gen, parsed)


if __name__ == '__main__':
    # Allow running this file directly as a test script.
    unittest.main()
|
||||
3070
caffe2/python/core.py
Normal file
3070
caffe2/python/core.py
Normal file
File diff suppressed because it is too large
Load Diff
1010
caffe2/python/core_gradients_test.py
Normal file
1010
caffe2/python/core_gradients_test.py
Normal file
File diff suppressed because it is too large
Load Diff
1264
caffe2/python/core_test.py
Normal file
1264
caffe2/python/core_test.py
Normal file
File diff suppressed because it is too large
Load Diff
313
caffe2/python/crf.py
Normal file
313
caffe2/python/crf.py
Normal file
@ -0,0 +1,313 @@
|
||||
## @package crf
|
||||
# Module caffe2.python.crf
|
||||
|
||||
|
||||
import numpy as np
|
||||
from caffe2.python import brew, core, model_helper, recurrent
|
||||
|
||||
|
||||
"""
|
||||
Due to a limitation in RecurrentNetworkOp, this layer only supports batch_size=1
|
||||
In order to support batch_size > 1, we will have to implement the CRFUnit
|
||||
and its gradient in C++ and handle the different batches there.
|
||||
"""
|
||||
|
||||
|
||||
class CRFWithLoss:
|
||||
def __init__(self, model, num_classes, transitions_blob=None):
    """Store the model and set up the CRF transitions parameter.

    Args:
        model: model helper that operators and parameters are added to.
        num_classes: number of label classes, without BOS/EOS.
        transitions_blob: optional existing transitions blob; when falsy,
            a square (num_classes + 2) matrix named "crf_transitions" is
            created with UniformFill in [-1, 1] on the init net.
    """
    self.model = model
    self.num_classes = num_classes
    # Two extra classes account for the BOS and EOS entries.
    padded = num_classes + 2
    self.num_classes_padded = padded

    if transitions_blob:
        self.transitions = transitions_blob
    else:
        self.transitions = self.model.param_init_net.UniformFill(
            [],
            [core.ScopedBlobReference("crf_transitions")],
            shape=[padded, padded],
            min=-1.0,
            max=1.0,
        )
    self.model.params.append(self.transitions)
|
||||
|
||||
def crf_loss(self, predictions, labels, seq_lengths=None):
    """Add the CRF loss operators to the model and return the loss blob.

    The loss is the all-paths score minus the score of the labelled path,
    where the labelled path score is the sum of its unary and binary
    (transition) scores.

    Args:
        predictions: blob of per-step class scores.
        labels: blob of label ids.
        seq_lengths: forwarded to ``_path_binary_scores`` only.

    Returns:
        The "crf_loss" blob produced by the final Sub operator.
    """
    net = self.model.net
    init_net = self.model.param_init_net

    # The transitions matrix is a shared parameter that can be updated
    # between the operators using it during parallel updates, so work
    # from a snapshot taken up front.
    snapshot = net.Copy(
        self.transitions, core.ScopedBlobReference("transitions_snapshot")
    )

    # Unary score of the labelled path, computed from the raw logits.
    unary_score = self._gather_entries_sum(
        predictions, labels, self.num_classes
    )

    # Add BOS and EOS entries to the predictions and labels.
    predictions = CRFWithLoss.pad_predictions(
        predictions, init_net, net, self.num_classes
    )
    labels = CRFWithLoss.pad_labels(
        labels, init_net, net, self.num_classes
    )

    # Binary score of the labelled path, from the transitions matrix.
    binary_score = self._path_binary_scores(labels, snapshot, seq_lengths)
    path_total_score = net.Add(
        [binary_score, unary_score],
        core.ScopedBlobReference("path_total"),
    )

    # Score over all paths: the first row seeds the recurrence and the
    # remaining rows are its input sequence.
    zero_index = init_net.ConstantFill([], shape=[1], value=0)
    initial_state = net.Gather(
        [predictions, zero_index],
        core.ScopedBlobReference("rnn_initial"),
        dense_gradient=True,
    )
    input_data, _ = net.RemovePadding(
        [predictions], padding_width=1, end_padding_width=0, outputs=2
    )
    input_data = net.ExpandDims(
        [input_data], core.ScopedBlobReference("rnn_input_data"), dims=[1]
    )
    # Due to a bug in RecurrentNetworkGradientOp, the transitions blob is
    # copied before being handed to the recurrent network.
    transitions_copy = net.Copy(
        snapshot, core.ScopedBlobReference("transitions_copy")
    )
    all_paths_scores = self._crf_forward(
        input_data, initial_state, transitions_copy
    )

    return net.Sub(
        [all_paths_scores, path_total_score],
        core.ScopedBlobReference("crf_loss"),
    )
|
||||
|
||||
def _path_binary_scores(self, labels, transitions, seq_lengths=None):
    """Sum the transition scores along the labelled path.

    Args:
        labels: blob of label ids (the caller passes the BOS/EOS-padded
            labels).
        transitions: transitions matrix blob.
        seq_lengths: accepted but not used by this method.

    Returns:
        The result of ReduceFrontSum over the gathered transition entries.
    """
    net = self.model.net

    # Dropping the first label gives the column ids; dropping the last
    # label gives the row ids of consecutive label pairs.
    column_ids, _ = net.RemovePadding(
        [labels], outputs=2, padding_width=1, end_padding_width=0
    )
    row_ids, _ = net.RemovePadding(
        [labels], outputs=2, padding_width=0, end_padding_width=1
    )

    # There is no multi-dimensional gather, so flatten the matrix to a
    # 1-d vector and gather at (row_ids * num_columns + column_ids).
    num_columns_blob = net.ConstantFill(
        [row_ids], value=self.num_classes_padded
    )
    scaled_rows = net.Mul([row_ids, num_columns_blob])
    flat_ids = net.Add([scaled_rows, column_ids])
    flat_transitions = net.FlattenToVec([transitions])
    entries = net.Gather([flat_transitions, flat_ids], dense_gradient=True)
    return self.model.ReduceFrontSum(entries)
|
||||
|
||||
def _gather_entries_sum(self, in_data, indices, index_size):
    """Sum the entries of ``in_data`` selected by ``indices``.

    The indices are cast to int64 and one-hot encoded with width
    ``index_size``; the flattened one-hot mask is dot-multiplied with the
    flattened data and the result reduced with ReduceFrontSum.

    Returns:
        The blob produced by the final ReduceFrontSum.
    """
    net = self.model.net

    int64_indices = net.Cast([indices], to="int64")
    width_blob = self.model.param_init_net.ConstantFill(
        [], shape=[1], value=index_size
    )
    one_hot = net.OneHot([int64_indices, width_blob])
    flat_mask = net.FlattenToVec(one_hot)
    flat_data = net.FlattenToVec(in_data)
    selected = net.DotProduct([flat_mask, flat_data])
    return net.ReduceFrontSum([selected])
|
||||
|
||||
def _crf_forward(
    self, input_blob, initial_state, transitions_copy, seq_lengths=None
):
    """Compute the accumulated score over all paths.

    Args:
        input_blob: input sequence blob passed to ``build_crf_net``.
        initial_state: initial recurrent state blob.
        transitions_copy: transitions blob passed to ``build_crf_net``.
        seq_lengths: accepted but not used by this method.

    Returns:
        The accumulated score blob, reshaped with ``shape=()``.
    """
    net = self.model.net
    padded = self.num_classes_padded

    # Build the RNN net and take its last timestep output.
    last_step = self.build_crf_net(input_blob, initial_state, transitions_copy)
    last_step, _ = net.Reshape([last_step], outputs=2, shape=(padded,))

    # All-zero segment ids put every entry into a single segment.
    zero_segment_id = self.model.param_init_net.ConstantFill(
        [], value=0, shape=[padded], dtype=core.DataType.INT32
    )

    # Log-sum-exp over that segment gives the total score of all paths.
    total = net.SortedSegmentRangeLogSumExp([last_step, zero_segment_id])
    total, _ = net.Reshape(total, outputs=2, shape=())
    return total
|
||||
|
||||
def build_crf_net(self, input_blob, initial_state, transitions):
    """
    Adds the crf_net recurrent operator to the model.

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension
    ##Only supports batch-size 1##

    initial_state: blob used as the initial value of the recurrent
    state (cell_t_prev)

    transitions: transitions blob; it is registered as an external input
    of the step net and added to the scores at every step

    Returns the second blob returned by recurrent.recurrent_net (the
    last-timestep output).
    """
    scope = "crf_net"
    padded = self.num_classes_padded

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        return "/".join((str(scope), str(name)))

    step_model = model_helper.ModelHelper(name="crf_step", param_model=self.model)
    step_net = step_model.net
    input_t, cell_t_prev, _ = step_net.AddExternalInputs(
        core.ScopedBlobReference("input_t"),
        core.ScopedBlobReference("cell_t_prev"),
        transitions,
    )
    zero_segment_id = step_model.param_init_net.ConstantFill(
        [],
        [s("zero_segment_id")],
        value=0,
        shape=[padded],
        dtype=core.DataType.INT32,
    )

    # A hack to bypass model cloning for test
    step_model.param_init_net.AddExternalOutput(zero_segment_id)

    # ---- the CRF step ----
    # Do tile
    prev_transpose = brew.transpose(
        step_model, cell_t_prev, [s("prev_transpose")], axes=(0, 2, 1)
    )
    prev_tiled = step_net.Tile(
        prev_transpose, [s("prev_tiled")], tiles=padded, axis=2
    )
    input_t_tiled = step_net.Tile(
        input_t, [s("input_t_tiled")], tiles=padded, axis=1
    )
    input_with_prev = step_net.Add(
        [prev_tiled, input_t_tiled], [s("input_with_prev")]
    )
    all_with_transitions = step_net.Add(
        [input_with_prev, transitions],
        [s("prev_with_transitions")],
        broadcast=1,
        use_grad_hack=1,
    )
    all_with_transitions_reshaped, _ = step_net.Reshape(
        all_with_transitions,
        [s("all_with_transitions_reshaped"), s("all_with_transitions_orig")],
        shape=(padded, padded),
    )
    cell_t = step_net.SortedSegmentRangeLogSumExp(
        [all_with_transitions_reshaped, zero_segment_id], [s("cell_t")]
    )
    step_net.AddExternalOutputs(cell_t)

    # ---- recurrent network ----
    # Only the last-timestep output is handed back to the caller.
    _, out_last = recurrent.recurrent_net(
        net=self.model.net,
        cell_net=step_net,
        inputs=[(input_t, input_blob)],
        initial_cell_inputs=[(cell_t_prev, initial_state)],
        links={cell_t_prev: cell_t},
        scope=scope,
        outputs_with_grads=(1,),
    )
    return out_last
|
||||
|
||||
def update_predictions(self, classes):
    """
    Adds a Python operator that makes the Viterbi path the arg-max path.

    classes is padded with CRFWithLoss.pad_predictions and fed, together
    with self.transitions, to a Python op that decodes the best tag
    sequence and rewrites the scores accordingly.

    Returns the output blob of the Python op ("post_crf_classes" in the
    current scope).
    """

    def crf_update_predictions_op(inputs, outputs):
        # This operator will compute the best path of classes by performing
        # Viterbi decoding and then updates the predictions to make the tag
        # on the best path have the highest score among the others
        predictions = inputs[0].data
        transitions = inputs[1].data
        predictions_shape = inputs[0].shape

        # Forward pass: best score reaching each tag at each timestep,
        # plus the predecessor tag that achieved it.
        trellis = np.zeros(predictions_shape)
        backpointers = np.zeros(predictions_shape, dtype=np.int32)
        trellis[0] = predictions[0]

        for t in range(1, predictions_shape[0]):
            v = np.expand_dims(trellis[t - 1], 1) + transitions
            trellis[t] = predictions[t] + np.max(v, 0)
            backpointers[t] = np.argmax(v, 0)

        # Backward pass: follow the backpointers from the best final tag.
        viterbi = [np.argmax(trellis[-1])]
        for bp in reversed(backpointers[1:]):
            viterbi.append(bp[viterbi[-1]])
        viterbi.reverse()

        # Work on a private copy so the swap below never writes into the
        # operator's input tensor.
        new_predictions = np.array(predictions, dtype=np.float64)
        for row, path_tag in zip(new_predictions, viterbi):
            # Swap the scores of the current best tag and the tag on the
            # Viterbi path
            old_best = np.argmax(row)
            row[path_tag], row[old_best] = row[old_best], row[path_tag]

        # Remove the BOS and EOS entries from the predictions matrix
        orig_predictions = new_predictions[1:-1, 0:-2]
        outputs[0].reshape(orig_predictions.shape)
        outputs[0].data[...] = orig_predictions

    padded_classes = CRFWithLoss.pad_predictions(
        classes, self.model.param_init_net, self.model.net, self.num_classes
    )
    new_classes = self.model.net.Python(crf_update_predictions_op)(
        [padded_classes, self.transitions],
        core.ScopedBlobReference("post_crf_classes"),
    )
    return new_classes
|
||||
|
||||
@staticmethod
def pad_labels(labels, init_net, net, num_classes):
    """
    Surrounds labels with begin- and end-of-sequence label ids.

    The BOS id is num_classes and the EOS id is num_classes + 1; both
    are created as single-element constants in init_net. labels is cast
    to int64 in net before the three pieces are concatenated on axis 0.

    Returns the concatenated label blob.
    """
    bos_blob = init_net.ConstantFill([], shape=[1], value=num_classes)
    eos_blob = init_net.ConstantFill([], shape=[1], value=num_classes + 1)
    int64_labels = net.Cast([labels], to="int64")
    padded, _ = net.Concat([bos_blob, int64_labels, eos_blob], axis=0, outputs=2)
    return padded
|
||||
|
||||
@staticmethod
def pad_predictions(predictions, init_net, net, num_classes):
    """
    Extends predictions with begin- and end-of-sequence labels.

    Two extra columns filled with a very low score are appended to
    predictions, then one row is prepended (score 0 in the BOS column)
    and one row is appended (score 0 in the EOS column).

    Returns the padded predictions blob.
    """
    low_score = -1000.0  # An arbitrary very low number
    padded_width = num_classes + 2

    bos_row = np.array([[low_score] * num_classes + [0, low_score]]).astype(
        np.float32
    )
    eos_row = np.array([[low_score] * num_classes + [low_score, 0]]).astype(
        np.float32
    )
    bos_blob = init_net.GivenTensorFill(
        [], "b_scores", shape=[1, padded_width], values=bos_row
    )
    eos_blob = init_net.GivenTensorFill(
        [], "e_scores", shape=[1, padded_width], values=eos_row
    )

    # Build a low-score column with one entry per row of predictions.
    zero_index = net.ConstantFill([], shape=[1], value=0)
    num_rows = net.Gather([net.Shape([predictions]), zero_index])
    num_rows = net.Cast(num_rows, to="int32")
    row_range = net.LengthsRangeFill(num_rows)
    pad_column = net.ConstantFill([row_range], value=low_score)
    pad_column = net.ExpandDims(pad_column, dims=[1])

    with_extra_columns, _ = net.Concat(
        [predictions, pad_column, pad_column], outputs=2, axis=1
    )
    with_extra_rows, _ = net.Concat(
        [bos_blob, with_extra_columns, eos_blob], outputs=2, axis=0
    )
    return with_extra_rows
|
||||
33
caffe2/python/crf_predict.py
Normal file
33
caffe2/python/crf_predict.py
Normal file
@ -0,0 +1,33 @@
|
||||
|
||||
|
||||
import numpy as np
|
||||
from caffe2.python.crf import CRFWithLoss
|
||||
|
||||
|
||||
def crf_update_predictions(model, crf_with_loss, classes):
    """
    Applies CRF decoding to classes using the nets of model.

    Convenience wrapper around apply_crf that takes the init net and net
    from model, and the transitions blob and number of classes from
    crf_with_loss. Returns whatever apply_crf returns.
    """
    init_net = model.param_init_net
    net = model.net
    transitions = crf_with_loss.transitions
    num_classes = crf_with_loss.num_classes
    return apply_crf(init_net, net, transitions, classes, num_classes)
|
||||
|
||||
|
||||
def apply_crf(init_net, net, transitions, predictions, num_classes):
    """
    Adds operators that make the Viterbi path the best-scoring path.

    predictions is padded with CRFWithLoss.pad_predictions, the best path
    is computed with ViterbiPath and applied with SwapBestPath, and the
    padding is removed again.

    Returns the blob holding the updated, unpadded predictions.
    """
    padded = CRFWithLoss.pad_predictions(
        predictions, init_net, net, num_classes
    )
    best_path = net.ViterbiPath([padded, transitions])
    swapped = net.SwapBestPath([padded, best_path])

    # Revert the effect of pad_predictions: drop one padding row from each
    # end, then slice off the two trailing columns.
    # NOTE(review): relies on Slice treating end index -1 as "through the
    # last element", so -3 leaves out the last two columns - confirm.
    without_pad_rows = net.RemovePadding(
        [swapped], padding_width=1, end_padding_width=1
    )
    starts = net.GivenTensorIntFill(
        [], shape=[2], values=np.array([0, 0]).astype(np.int32)
    )
    ends = net.GivenTensorIntFill(
        [], shape=[2], values=np.array([-1, -3]).astype(np.int32)
    )
    return net.Slice([without_pad_rows, starts, ends])
|
||||
46
caffe2/python/crf_viterbi_test.py
Normal file
46
caffe2/python/crf_viterbi_test.py
Normal file
@ -0,0 +1,46 @@
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import workspace, crf
|
||||
|
||||
from caffe2.python.cnn import CNNModelHelper
|
||||
from caffe2.python.crf_predict import crf_update_predictions
|
||||
from caffe2.python.test_util import TestCase
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis import given, settings
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestCrfDecode(TestCase):
    """Checks crf_update_predictions against CRFWithLoss.update_predictions."""

    @given(num_tags=st.integers(2, 4), num_words=st.integers(2, 15))
    @settings(deadline=2000)
    def test_crf_viterbi(self, num_tags, num_words):
        # Transitions cover the real tags plus the two padding tags.
        padded_tags = num_tags + 2

        model = CNNModelHelper(name='external')
        predictions = np.random.randn(num_words, num_tags).astype(np.float32)
        transitions = np.random.uniform(
            low=-1, high=1, size=(padded_tags, padded_tags)
        ).astype(np.float32)
        predictions_blob, transitions_blob = model.net.AddExternalInputs(
            'predictions', 'crf_transitions'
        )
        workspace.FeedBlob(str(transitions_blob), transitions)
        workspace.FeedBlob(str(predictions_blob), predictions)
        crf_layer = crf.CRFWithLoss(model, num_tags, transitions_blob)

        # Operator-based implementation under test, then the Python-op
        # reference implementation, both added to the same net.
        operator_blob = crf_update_predictions(
            model, crf_layer, predictions_blob
        )
        reference_blob = crf_layer.update_predictions(predictions_blob)

        for net in (model.param_init_net, model.net):
            workspace.RunNetOnce(net)

        np.testing.assert_allclose(
            workspace.FetchBlob(str(operator_blob)),
            workspace.FetchBlob(str(reference_blob)),
            atol=1e-4, rtol=1e-4, err_msg='Mismatch in CRF predictions'
        )
|
||||
2221
caffe2/python/data_parallel_model.py
Normal file
2221
caffe2/python/data_parallel_model.py
Normal file
File diff suppressed because it is too large
Load Diff
1427
caffe2/python/data_parallel_model_test.py
Normal file
1427
caffe2/python/data_parallel_model_test.py
Normal file
File diff suppressed because it is too large
Load Diff
461
caffe2/python/data_workers.py
Normal file
461
caffe2/python/data_workers.py
Normal file
@ -0,0 +1,461 @@
|
||||
## @package data_workers
|
||||
# Module caffe2.python.data_workers
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
'''
|
||||
This module provides a python-land multithreaded data input mechanism
|
||||
for Caffe2 nets.
|
||||
|
||||
Basic usage is as follows:
|
||||
coordinator = data_workers.init_data_input_workers(
|
||||
net,
|
||||
["data", "label"],
|
||||
my_fetch_fun,
|
||||
batch_size=32,
|
||||
input_source_name="train",
|
||||
dont_rebatch=False
|
||||
)
|
||||
...
|
||||
coordinator.start()
|
||||
|
||||
First argument is the Caffe2 net (or model helper), and second argument
|
||||
is list of input blobs that are to be fed.
|
||||
|
||||
Argument 'input_source_name' is used to distinguish different sources of data,
|
||||
such as train or test data. This is to ensure the data does not get mixed up,
|
||||
although two nets would share blobs.
|
||||
|
||||
To do the actual data loading, one defines a "fetcher function"
|
||||
that has call signature
|
||||
my_fetch_fun(worker_id, batch_size)
|
||||
|
||||
Optionally, one can define a "init function" that is called once before
|
||||
threads start, and has call signature:
|
||||
my_init_fun(data_coordinator, global_coordinator)
|
||||
|
||||
If dont_rebatch is set to True, the data input is not batched into equal sized
|
||||
chunks but data directly provided by fetchers is used.
|
||||
|
||||
'batch_columns' can be used to specify which dimension is the batch dimension,
|
||||
for each of the inputs. Default is 0 for all inputs.
|
||||
|
||||
'timeout' is the timeout in seconds after which if no data is available, the
|
||||
net will fail (default 600s = 10 mins).
|
||||
|
||||
The fetcher function returns a list of numpy arrays corresponding to the different
|
||||
input blobs. In the example above, it would return two arrays, one for the
|
||||
data blob and another for the labels. These arrays can have arbitrary number
|
||||
of elements (i.e they do not need to match the batch size). The batch size
|
||||
is provided for the function as a hint only.
|
||||
|
||||
For example, fetcher function could download images from a remote service or
|
||||
load random images from a directory on a file system.
|
||||
|
||||
For a dummy example, see the data_workers_test unit test.
|
||||
|
||||
Note that for data_parallel_models, init_data_input_workers will be called
|
||||
for each GPU. Note that the 'coordinator' returned by the function is same
|
||||
each time.
|
||||
'''
|
||||
|
||||
import queue as Queue
|
||||
from itertools import chain
|
||||
import logging
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
from caffe2.python import workspace, core, scope, utils
|
||||
from caffe2.proto import caffe2_pb2
|
||||
from caffe2.python.parallel_workers import Metrics, State, \
|
||||
WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker
|
||||
|
||||
# Logger shared by every component in this module.
log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
# Minimum number of seconds between two periodic throughput reports and
# between two "data loading lagging behind" warnings.
LOG_INT_SECS = 60
|
||||
|
||||
|
||||
def get_worker_ids(num_workers):
    """Return the worker ids 0 .. num_workers - 1 as a list."""
    return [worker_id for worker_id in range(num_workers)]
|
||||
|
||||
|
||||
def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
    dont_rebatch=False,
    batch_columns=None,
    timeout=600
):
    """
    Set up fetcher threads and an enqueuer thread feeding `net`.

    A BatchFeeder is created for the current device and name scope, along
    with `num_worker_threads` fetcher threads (each running `fetch_fun`
    through a DataWorker) and one enqueuer thread. The threads are
    attached to a new WorkerCoordinator, which is registered with the
    module-level global coordinator. No thread is started here.

    Returns the module-level global coordinator.
    """
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        # No device scope active: default to CPU.
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    metrics = Metrics(external_loggers)
    batch_feeder = BatchFeeder(
        net,
        input_blob_names,
        batch_size,
        device_option,
        scope.CurrentNameScope(),
        input_source_name,
        global_coordinator.get_queue(input_source_name, max_buffered_batches),
        metrics,
        dont_rebatch,
        batch_columns,
        timeout=timeout
    )

    # Reserve one globally unique id per fetcher thread
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for _ in range(num_worker_threads)
    ]

    # Create coordinator object
    coordinator = WorkerCoordinator(
        input_source_name, worker_ids, init_fun, batch_feeder)

    # One fetcher thread per worker id ...
    threads = []
    for worker_id in worker_ids:
        data_worker = DataWorker(
            coordinator, worker_id, fetch_fun, metrics,
            batch_size, batch_feeder)
        threads.append(threading.Thread(
            target=run_worker,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator, data_worker],
        ))

    # ... plus a single thread moving batches into the Caffe2 queues.
    threads.append(threading.Thread(
        target=enqueuer,
        name="Enqueuer {} {}".format(input_source_name, scope.CurrentNameScope()),
        args=[coordinator, batch_feeder]))

    coordinator._workers = threads
    global_coordinator.add(coordinator)

    return global_coordinator
|
||||
|
||||
|
||||
class BatchFeeder(State):
    '''
    Moves fetched data from a python-side queue into Caffe2 blob queues.

    Fetcher threads hand chunks to put(); the enqueuer thread repeatedly
    calls _enqueue_batch(), which (unless dont_rebatch is set) glues chunks
    together into arrays of exactly batch_size samples and pushes them into
    one Caffe2 blobs queue per input blob. The net reads them back through
    the DequeueBlobs operators added by _create_caffe2_ops().
    '''

    def __init__(self, net, input_blob_names, batch_size,
                 device_option, namescope, input_source_name, queue,
                 metrics, dont_rebatch, batch_columns, timeout=600):
        '''
        Creates the Caffe2 queues, the dequeue ops on `net` and the scratch
        blobs. `queue` is the python-side queue shared with the fetchers;
        `batch_columns` gives the batch axis of every input (all 0 when
        None); `timeout` is forwarded to the dequeue ops in seconds.
        '''
        self._counter = 0
        self._input_blob_names = input_blob_names
        self._batch_size = batch_size
        self._internal_queue = queue
        self._queues = []
        self._device_option = device_option
        self._namescope = namescope
        self._timeout = timeout
        self._input_source_name = input_source_name
        self._c2_queue_capacity = 4
        self._create_caffe2_queues(net)
        self._create_caffe2_ops(net)
        self._inputs = 0
        self._prev_seconds = 0
        self._last_warning = time.time()
        self._dont_rebatch = dont_rebatch
        self._init_scratch()
        self._metrics = metrics

        if batch_columns is None:
            batch_columns = [0 for _ in input_blob_names]
        self._batch_columns = batch_columns

    def start(self):
        '''Resets the throughput counters used by the periodic log.'''
        self._inputs = 0
        self._prev_seconds = time.time()

    def stop(self):
        '''Closes every Caffe2 queue, then forces a final throughput log.'''
        try:
            for q in self._queues:
                workspace.RunOperatorOnce(
                    core.CreateOperator("CloseBlobsQueue", [q], [])
                )
        finally:
            self._log_inputs_per_interval(0, force=True)

    def cleanup(self):
        '''Resets the scratch data blobs and scratch status blobs.'''
        utils.ResetBlobs(self._scratch_blob.values())
        utils.ResetBlobs(self._scratch_status.values())

    def _get(self, data_input_coordinator):
        '''
        Blocks until a chunk is available on the python-side queue.

        Polls in 0.5 second slices so that shutdown is noticed, and warns
        at most every 10 seconds while no data arrives. Returns None once
        the coordinator is no longer active.
        '''
        start_time = time.time()
        last_warning = time.time()
        while data_input_coordinator.is_active():
            try:
                return self._internal_queue.get(block=True, timeout=0.5)
            except Queue.Empty:
                if time.time() - last_warning > 10.0:
                    log.warning("** Data input is slow: (still) no data in {} secs.".format(
                        time.time() - start_time))
                    last_warning = time.time()
                continue
        return None

    def _validate_chunk(self, chunk):
        '''
        Checks a chunk returned by a fetcher function.

        Returns False (after a warning) for None or zero-length chunks and
        True otherwise. Asserts that there is one numpy array per input
        blob and, when rebatching, that all arrays agree on the number of
        samples along their batch column.
        '''
        if chunk is None:
            log.warning("Fetcher function returned None")
            return False

        assert len(chunk) == len(self._input_blob_names), \
            "Expecting data blob for each input"
        for d in chunk:
            assert isinstance(d, np.ndarray), \
                "Fetcher function must return a numpy array"
        if not self._dont_rebatch:
            for j, d in enumerate(chunk[1:], start=1):
                assert d.shape[self._batch_columns[j]] == \
                    chunk[0].shape[self._batch_columns[0]], \
                    "Each returned input must have equal number of samples"

        if len(chunk) == 0:
            log.warning("Worker provided zero length input")
            return False

        return True

    def put(self, chunk, data_input_coordinator):
        '''
        Adds a validated chunk to the python-side queue.

        Retries in 0.5 second slices while the queue is full and gives up
        silently once the coordinator is no longer active. Invalid chunks
        are ignored.
        '''
        if not self._validate_chunk(chunk):
            return

        while data_input_coordinator.is_active():
            try:
                qsize = self._internal_queue.qsize()
                if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
                    log.warning("Warning, data loading lagging behind: " +
                                "queue size={}, name={}".format(qsize, self._input_source_name))
                    self._last_warning = time.time()
                self._counter += 1
                self._internal_queue.put(chunk, block=True, timeout=0.5)
                self._log_inputs_per_interval(chunk[0].shape[0])
                return
            except Queue.Full:
                log.debug("Queue full: stalling fetchers...")
                continue

    def _enqueue_batch_direct(self, data_input_coordinator):
        '''Forwards one chunk to the Caffe2 queues without rebatching.'''
        data = self._get(data_input_coordinator)
        if data is None:
            return
        if data_input_coordinator.is_active():
            for b, q, c in zip(self._input_blob_names, self._queues, data):
                self._enqueue(b, q, c)

    def _enqueue_batch(self, data_input_coordinator):
        '''
        This pulls data from the python-side queue and collects them
        into batch-sized pieces, unless dont_rebatch is set to true.
        '''
        if self._dont_rebatch:
            self._enqueue_batch_direct(data_input_coordinator)
            return

        cur_batch = [np.array([]) for _ in self._input_blob_names]
        first_batch_col = self._batch_columns[0]

        # Collect data until we have a full batch size
        while (
            cur_batch[0].shape[0] == 0 or
            cur_batch[0].shape[first_batch_col] < self._batch_size
        ) and data_input_coordinator.is_active():
            chunk = self._get(data_input_coordinator)
            if chunk is None:
                continue

            for j, chunk_elem in enumerate(chunk):
                if cur_batch[j].shape[0] == 0:
                    cur_batch[j] = chunk_elem.copy()
                else:
                    cur_batch[j] = np.append(
                        cur_batch[j], chunk_elem, axis=self._batch_columns[j]
                    )

        start_time = time.time()
        try:
            # Return data over the batch size back to queue
            if cur_batch[0].shape[0] > 0 and cur_batch[0].shape[
                first_batch_col
            ] > self._batch_size:
                leftover = []
                trimmed_batch = []
                for j, b in enumerate(cur_batch):
                    [c, l] = np.split(
                        b, [self._batch_size], axis=self._batch_columns[j]
                    )
                    leftover.append(l)
                    trimmed_batch.append(c)
                cur_batch = trimmed_batch
                try:
                    self._internal_queue.put(leftover, block=False)
                except Queue.Full:
                    # Best effort only: the surplus samples are discarded
                    # when the queue is full, but say so instead of losing
                    # them silently.
                    log.warning(
                        "Internal queue full, dropping {} leftover samples "
                        "for {}".format(
                            leftover[0].shape[first_batch_col],
                            self._input_source_name,
                        )
                    )

                assert cur_batch[0].shape[first_batch_col] == self._batch_size

            if data_input_coordinator.is_active():
                for b, q, c in zip(
                    self._input_blob_names, self._queues, cur_batch
                ):
                    self._enqueue(b, q, c)
        finally:
            self._metrics.put_metric('enqueue_time', time.time() - start_time)

    def _init_scratch(self):
        '''Creates one scratch data blob and one status blob per input.'''
        self._scratch_blob = {}
        self._scratch_status = {}
        for blob_name in self._input_blob_names:
            scratch_name = self._namescope + blob_name + \
                "_scratch_" + self._input_source_name
            self._scratch_blob[blob_name] = core.BlobReference(scratch_name)
            self._scratch_status[blob_name] = core.BlobReference(
                scratch_name + "_status"
            )

        # Feed empty arrays to the scratch blobs here, so that there won't be
        # race conditions when calling FeedBlob (which calls workspace
        # CreateBlob()) from enqueue threads
        for b in chain(
            self._scratch_blob.values(), self._scratch_status.values()
        ):
            workspace.FeedBlob(
                b,
                np.array([]).astype(np.float32),
                device_option=self._device_option,
            )

    def _enqueue(self, blob_name, queue, data_arr):
        '''
        Enqueue the correctly sized batch arrays to Caffe2's queue.
        '''
        workspace.FeedBlob(
            self._scratch_blob[blob_name],
            data_arr,
            device_option=self._device_option
        )

        op = core.CreateOperator(
            "SafeEnqueueBlobs",
            [queue, self._scratch_blob[blob_name]],
            [self._scratch_blob[blob_name], self._scratch_status[blob_name]],
            device_option=self._device_option
        )
        workspace.RunOperatorOnce(op)

    def _create_caffe2_queues(self, net):
        '''
        Creates queues on caffe2 side
        '''
        def create_queue(queue_name, num_blobs, capacity):
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CreateBlobsQueue",
                    [], [queue_name],
                    num_blobs=num_blobs,
                    capacity=capacity))
            return core.ScopedBlobReference(queue_name)

        for blob_name in self._input_blob_names:
            qname = blob_name + "_c2queue" + "_" + self._input_source_name
            q = create_queue(
                qname, num_blobs=1, capacity=self._c2_queue_capacity
            )
            self._queues.append(q)

    def _create_caffe2_ops(self, net):
        '''
        Creates dequeue-ops on caffe2 side
        '''
        for q, blob_name in zip(self._queues, self._input_blob_names):
            # Add operator to the Caffe2 network to dequeue
            net.DequeueBlobs(q, blob_name, timeout_secs=float(self._timeout))

    def _log_inputs_per_interval(self, inputs, force=False):
        '''
        Accumulates `inputs` and reports throughput every LOG_INT_SECS.

        With force=True the report is emitted immediately. Reporting also
        logs and resets the metrics and restarts the measuring interval.
        '''
        self._inputs += inputs
        current_seconds = time.time()
        delta_seconds = current_seconds - self._prev_seconds
        if delta_seconds >= LOG_INT_SECS or force:
            # A forced report can come right after start(), when a coarse
            # clock may report no elapsed time at all.
            if delta_seconds > 0:
                inputs_per_sec = int(self._inputs / delta_seconds)
            else:
                inputs_per_sec = 0
            qsize = self._internal_queue.qsize()
            log.info("{}/{}: {} inputs/sec".format(
                self._input_source_name,
                self._namescope,
                inputs_per_sec,
            ))
            log.info("-- queue: {} batches".format(qsize))
            # log and reset perf metrics
            self._metrics.put_metric(
                'inputs_per_sec', inputs_per_sec, False)
            self._metrics.put_metric('queue_size', qsize, False)
            self._metrics.put_metric(
                'time_elapsed', delta_seconds, False)
            self._metrics.log_metrics()
            self._metrics.reset_metrics()
            self._inputs = 0
            self._prev_seconds = current_seconds
|
||||
|
||||
|
||||
class GlobalCoordinator(GlobalWorkerCoordinator):
    '''
    Global worker coordinator that also owns the python-side data queues,
    one per input source name.
    '''

    def __init__(self):
        super().__init__()
        self._queues = {}

    def get_queue(self, queue_name, max_buffered_batches):
        '''
        Returns the queue registered under `queue_name`, creating it with
        capacity `max_buffered_batches` on first use. The capacity of an
        already existing queue is left unchanged.
        '''
        assert isinstance(max_buffered_batches, int)
        try:
            return self._queues[queue_name]
        except KeyError:
            new_queue = Queue.Queue(maxsize=max_buffered_batches)
            self._queues[queue_name] = new_queue
            return new_queue

    def reset_data_input(self, namescope, name, net, batch_size):
        '''
        Updates the batch size of every coordinator matching `name` and
        `namescope`, and adds its dequeue ops to `net`.
        '''
        log.info("Reset data input {}, batch size {}: ".format(name, batch_size))
        for c in self._coordinators:
            if c._worker_name != name:
                continue
            feeder = c._state
            if feeder._namescope != namescope:
                continue
            feeder._batch_size = batch_size
            feeder._create_caffe2_ops(net)
|
||||
|
||||
|
||||
class DataWorker(Worker):
    '''
    Worker that calls the fetcher function and hands the result to a
    BatchFeeder.
    '''

    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun,
        metrics,
        batch_size,
        batch_feeder
    ):
        super().__init__(
            coordinator, worker_id, worker_fun=worker_fun, metrics=metrics)
        self._batch_size = batch_size
        self._batch_feeder = batch_feeder

    def run(self):
        '''Fetches one chunk and puts it on the feeder's queue.'''
        fetched = self._worker_fun(self._worker_id, self._batch_size)
        self._batch_feeder.put(fetched, self._coordinator)

    def finish(self):
        '''Records the time elapsed since self._start_time as fetcher_time.'''
        elapsed = time.time() - self._start_time
        self._metrics.put_metric('fetcher_time', elapsed)
|
||||
|
||||
|
||||
# Single module-level coordinator: init_data_input_workers() registers every
# input source with it and returns it to the caller.
global_coordinator = GlobalCoordinator()
|
||||
|
||||
|
||||
def enqueuer(coordinator, batch_feeder):
    '''
    Thread body: keeps asking `batch_feeder` to enqueue batches until
    `coordinator` reports that it is no longer active.
    '''
    while True:
        if not coordinator.is_active():
            return
        batch_feeder._enqueue_batch(coordinator)
|
||||
196
caffe2/python/data_workers_test.py
Normal file
196
caffe2/python/data_workers_test.py
Normal file
@ -0,0 +1,196 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import time
|
||||
|
||||
from caffe2.python import workspace, model_helper
|
||||
from caffe2.python import timeout_guard
|
||||
import caffe2.python.data_workers as data_workers
|
||||
|
||||
|
||||
def dummy_fetcher(fetcher_id, batch_size):
    """Return a random-length chunk whose rows are tagged with their label.

    Between 1 and 64 samples are produced. All three columns of data row j
    hold the value j + fetcher_id, and labels[j] holds that same value, so
    a consumer can check that data and labels stayed aligned.
    batch_size is ignored.
    """
    # Create random amount of values
    n = np.random.randint(64) + 1
    # Fill the rows with distinct values: a zero matrix that is only ever
    # multiplied stays all-zero and makes the alignment checks vacuous.
    row_values = np.arange(n, dtype=np.float64) + fetcher_id
    data = np.repeat(row_values[:, np.newaxis], 3, axis=1)
    labels = row_values.copy()
    return [data, labels]
|
||||
|
||||
|
||||
def dummy_fetcher_rnn(fetcher_id, batch_size):
    """Return random [data, label, seq_lengths] arrays for RNN-style input.

    data has shape (20, batch_size, 33), label has shape (20, batch_size)
    and seq_lengths has shape (batch_size,); the integer arrays are drawn
    from [0, batch_size). fetcher_id is ignored.
    """
    # Hardcoding some input blobs
    seq_len, feature_dim = 20, 33
    data = np.random.rand(seq_len, batch_size, feature_dim)
    label = np.random.randint(batch_size, size=(seq_len, batch_size))
    seq_lengths = np.random.randint(batch_size, size=batch_size)
    return [data, label, seq_lengths]
|
||||
|
||||
|
||||
class DataWorkersTest(unittest.TestCase):
    """Tests for the threaded data input machinery in data_workers."""

    def testNonParallelModel(self):
        # Rebatched input: every dequeued batch must hold exactly 32
        # aligned data/label rows.
        workspace.ResetWorkspace()

        model = model_helper.ModelHelper(name="test")
        seq_id_before = data_workers.global_coordinator._fetcher_id_seq
        coordinator = data_workers.init_data_input_workers(
            model,
            ["data", "label"],
            dummy_fetcher,
            32,
            2,
            input_source_name="unittest"
        )
        seq_id_after = data_workers.global_coordinator._fetcher_id_seq
        # Two worker threads were requested, so two ids were handed out.
        self.assertEqual(seq_id_after, seq_id_before + 2)

        coordinator.start()

        workspace.RunNetOnce(model.param_init_net)
        workspace.CreateNet(model.net)
        net_name = model.net.Proto().name

        for _i in range(500):
            with timeout_guard.CompleteInTimeOrDie(5):
                workspace.RunNet(net_name)

            data = workspace.FetchBlob("data")
            labels = workspace.FetchBlob("label")

            self.assertEqual(data.shape[0], labels.shape[0])
            self.assertEqual(data.shape[0], 32)

            for j in range(32):
                for col in range(3):
                    self.assertEqual(labels[j], data[j, col])

        coordinator.stop_coordinator("unittest")
        self.assertEqual(coordinator._coordinators, [])

    def testRNNInput(self):
        # Inputs batched along different axes; mainly checks that the
        # coordinator can be stopped while the enqueue thread is blocked.
        workspace.ResetWorkspace()
        model = model_helper.ModelHelper(name="rnn_test")
        seq_id_before = data_workers.global_coordinator._fetcher_id_seq
        coordinator = data_workers.init_data_input_workers(
            model,
            ["data1", "label1", "seq_lengths1"],
            dummy_fetcher_rnn,
            32,
            2,
            dont_rebatch=False,
            batch_columns=[1, 1, 0],
        )
        seq_id_after = data_workers.global_coordinator._fetcher_id_seq
        self.assertEqual(seq_id_after, seq_id_before + 2)

        coordinator.start()

        workspace.RunNetOnce(model.param_init_net)
        workspace.CreateNet(model.net)

        # Wait until the fetchers have produced some input
        while coordinator._coordinators[0]._state._inputs < 100:
            time.sleep(0.01)

        # Run a couple of rounds
        net_name = model.net.Proto().name
        for _round in range(2):
            workspace.RunNet(net_name)

        # Wait for the enqueue thread to get blocked
        time.sleep(0.2)

        # We don't dequeue on caffe2 side (as we don't run the net)
        # so the enqueue thread should be blocked.
        # Let's now shutdown and see it succeeds.
        self.assertTrue(coordinator.stop())

    @unittest.skip("Test is flaky: https://github.com/pytorch/pytorch/issues/9064")
    def testInputOrder(self):
        #
        # Create two models (train and validation) with same input blobs
        # names and ensure that both will get the data in correct order
        #
        workspace.ResetWorkspace()
        self.counters = {0: 0, 1: 1}

        def dummy_fetcher_rnn_ordered1(fetcher_id, batch_size):
            # Hardcoding some input blobs
            T = 20
            N = batch_size
            D = 33
            # Stamp the current counter value into the first element of
            # every returned array, then advance the counter.
            stamp = self.counters[fetcher_id]
            data = np.zeros((T, N, D))
            data[0][0][0] = stamp
            label = np.random.randint(N, size=(T, N))
            label[0][0] = stamp
            seq_lengths = np.random.randint(N, size=(N))
            seq_lengths[0] = stamp
            self.counters[fetcher_id] += 1
            return [data, label, seq_lengths]

        workspace.ResetWorkspace()
        model = model_helper.ModelHelper(name="rnn_test_order")

        coordinator = data_workers.init_data_input_workers(
            model,
            input_blob_names=["data2", "label2", "seq_lengths2"],
            fetch_fun=dummy_fetcher_rnn_ordered1,
            batch_size=32,
            max_buffered_batches=1000,
            num_worker_threads=1,
            dont_rebatch=True,
            input_source_name='train'
        )
        coordinator.start()

        val_model = model_helper.ModelHelper(name="rnn_test_order_val")
        coordinator1 = data_workers.init_data_input_workers(
            val_model,
            input_blob_names=["data2", "label2", "seq_lengths2"],
            fetch_fun=dummy_fetcher_rnn_ordered1,
            batch_size=32,
            max_buffered_batches=1000,
            num_worker_threads=1,
            dont_rebatch=True,
            input_source_name='val'
        )
        coordinator1.start()

        workspace.RunNetOnce(model.param_init_net)
        workspace.CreateNet(model.net)
        workspace.CreateNet(val_model.net)

        while coordinator._coordinators[0]._state._inputs < 900:
            time.sleep(0.01)

        with timeout_guard.CompleteInTimeOrDie(5):
            for m in (model, val_model):
                print(m.net.Proto().name)
                workspace.RunNet(m.net.Proto().name)
                last_data = workspace.FetchBlob('data2')[0][0][0]
                last_lab = workspace.FetchBlob('label2')[0][0]
                last_seq = workspace.FetchBlob('seq_lengths2')[0]

                # Run few rounds
                for _i in range(10):
                    workspace.RunNet(m.net.Proto().name)
                    data = workspace.FetchBlob('data2')[0][0][0]
                    lab = workspace.FetchBlob('label2')[0][0]
                    seq = workspace.FetchBlob('seq_lengths2')[0]
                    # Every stamp must be exactly one past the previous.
                    self.assertEqual(data, last_data + 1)
                    self.assertEqual(lab, last_lab + 1)
                    self.assertEqual(seq, last_seq + 1)
                    last_data, last_lab, last_seq = data, lab, seq

        time.sleep(0.2)

        self.assertTrue(coordinator.stop())
|
||||
635
caffe2/python/dataio.py
Normal file
635
caffe2/python/dataio.py
Normal file
@ -0,0 +1,635 @@
|
||||
## @package dataio
|
||||
# Module caffe2.python.dataio
|
||||
"""
|
||||
Defines the base interface for reading and writing operations.
|
||||
|
||||
Readers/Writers are objects that produce operations that read/write sequences
|
||||
of data. Each operation reads or writes a list of BlobReferences.
|
||||
|
||||
Readers and Writers must be implemented such that read and write operations
|
||||
are atomic and thread safe.
|
||||
|
||||
Examples of possible Readers and Writers:
|
||||
QueueReader, QueueWriter,
|
||||
DatasetReader, DatasetWriter,
|
||||
|
||||
See `dataset.py` for an example of implementation.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python import core
|
||||
from caffe2.python.schema import Field, Struct, from_blob_list
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
class Reader:
    """
    Abstract interface for objects that add read operations to a net.

    Implementations iterate through a dataset or a stream of data. A Reader
    must implement at least one operation, `read`, which adds operations to
    a net that read the next batch of data. Readers can optionally support
    the `reset` operation, which is useful when multiple passes over the
    data are required.
    """
    def __init__(self, schema=None):
        """Store `schema`; when provided it must be a schema `Field`."""
        if schema is not None:
            assert isinstance(schema, Field)
        self._schema = schema

    def schema(self):
        """Return the reader's schema; asserts that one has been set."""
        assert self._schema is not None, 'Schema not provided for this reader.'
        return self._schema

    def _set_schema(self, schema):
        """Replace the stored schema without any validation."""
        self._schema = schema

    def setup_ex(self, init_net, finish_net):
        """Setup nets to run at task initialization and cleanup time.

        The base implementation does nothing.

        Args:
            init_net: A net invoked at task init time.
            finish_net: A net invoked at task cleanup time.
        """
        pass

    def read_ex(self, local_init_net, local_finish_net):
        """Create a 'reader_body' net, fill it through `read`, and return
        the tuple `([read_net], should_stop, fields)`.

        `local_init_net` and `local_finish_net` are not used by this base
        implementation.
        """
        read_net = core.Net('reader_body')
        return ([read_net], ) + self.read(read_net)

    def read_record_ex(self, local_init_net, local_finish_net):
        """Like `read_ex`, but when the reader has a (truthy) schema the
        returned fields are wrapped into a record via `from_blob_list`.
        """
        nets, should_stop, fields = self.read_ex(
            local_init_net, local_finish_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return nets, should_stop, fields

    def read(self, read_net):
        """Append operations to read_net that will read a batch from the
        underlying data source.

        Operations added to `read_net` must be thread safe and atomic, that is,
        it should be possible to clone `read_net` and run multiple instances of
        it in parallel.

        Args:
            read_net: the net that will be appended with read operations

        Returns:
            A tuple (should_stop, fields), with:
                should_stop: BlobReference pointing to a boolean scalar
                             blob that indicates whether the read operation
                             was successful or whether the end of data has
                             been reached.
                fields: A tuple of BlobReference containing the latest batch
                        of data that was read.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('Readers must implement `read`.')

    def reset(self, net):
        """Append operations to `net` that will reset the reader.

        This can be used to read the data multiple times.
        Not all readers support this operation.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('This reader cannot be resetted.')

    def read_record(self, read_net):
        """Call `read` on `read_net` and, when the reader has a (truthy)
        schema, wrap the returned fields into a record via `from_blob_list`.

        Returns:
            A tuple (should_stop, fields).
        """
        should_stop, fields = self.read(read_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return should_stop, fields

    def execution_step(self, reader_net_name=None, external_should_stop=None):
        """Create an execution step with a net containing read operators.

        The execution step will contain a `stop_blob` that knows how to stop
        the execution loop when end of data was reached.

        E.g.:

            read_step, fields = reader.execution_step()
            consume_net = core.Net('consume')
            consume_net.Print(fields[0], [])
            p = core.Plan('reader')
            p.AddStep(read_step.AddNet(consume_net))
            core.RunPlan(p)

        Args:
            reader_net_name: (optional) the name of the reader_net to be
                             created. The execution step will
                             be named accordingly. Defaults to 'reader'.
            external_should_stop: (optional) a blob that is OR-ed with the
                             reader's own stop signal; the result is used
                             as the step's stop blob.

        Returns:
            A tuple (read_step, fields), with:
                read_step: A newly created execution step containing a net with
                           read operations. The step will have `stop_blob` set,
                           in order to stop the loop on end of data.
                fields: A tuple of BlobReference containing the latest batch
                        of data that was read.
        """
        # Resolve the default once so that the net and the step share the
        # same base name; formatting the raw argument would name the step
        # 'None_step' whenever no name is supplied.
        net_name = reader_net_name or 'reader'
        reader_net = core.Net(net_name)
        should_stop, fields = self.read_record(reader_net)
        if external_should_stop is not None:
            should_stop = reader_net.Or([external_should_stop, should_stop])
        read_step = core.execution_step(
            '{}_step'.format(net_name),
            reader_net,
            should_stop_blob=should_stop)
        return (read_step, fields)
|
||||
|
||||
|
||||
class Writer:
    """
    Abstract interface for objects that add write operations to a net, in
    order to feed a data stream or a dataset.

    A Writer must implement 2 operations:
    `write`, which adds operations to a net that write the next batch of
    data, and `commit`, which adds operations to a net in order to indicate
    that no more data will be written.
    """
    # Schema of the last record passed to write_record / write_record_ex.
    _schema = None

    def schema(self):
        """Return the schema recorded so far, or None."""
        return self._schema

    def write(self, writer_net, fields):
        """Add operations to `writer_net` that write the next batch of data.

        Operations added to the net must be thread-safe and unique, that is:
        multiple writers must be able to write to the dataset in parallel.

        Args:
            fields: a tuple of BlobReference containing the batch of data to
                    write.
        """
        raise NotImplementedError('Writers must implement write.')

    def write_record(self, writer_net, fields):
        """Write `fields` through `write`.

        When `fields` is a schema `Field`, it is remembered as this writer's
        schema and its field blobs are what gets written.
        """
        blobs = fields
        if isinstance(fields, Field):
            self._schema = fields
            blobs = fields.field_blobs()
        self.write(writer_net, blobs)

    def setup_ex(self, init_net, finish_net):
        """Experimental, don't use yet"""
        self.commit(finish_net)

    def write_ex(self, fields, local_init_net, local_finish_net, stop_blob):
        """Experimental extension to the interface. Don't use yet"""
        net = core.Net('write_net')
        self.write(net, fields)
        return [net]

    def write_record_ex(
            self, fields, local_init_net, local_finish_net, stop_blob=None):
        """Experimental extension to the interface. Don't use yet."""
        if isinstance(fields, Field):
            self._schema = fields
            fields = fields.field_blobs()
        # Allocate a status blob name only when the caller did not pass one.
        status_blob = (
            local_init_net.NextName("dequeue_status")
            if stop_blob is None else stop_blob
        )
        nets = self.write_ex(
            fields, local_init_net, local_finish_net, status_blob)
        return (nets, status_blob)

    def commit(self, finish_net):
        """Add operations to `finish_net` that signal end of data.

        This must be implemented by all Writers, but may be no-op for some
        of them.
        """
        pass
|
||||
|
||||
|
||||
class ReaderBuilder:
    """ Allow usage of a reader in distributed fashion. """

    def schema(self):
        """Return the schema of the readers this builder creates.

        Must be provided by subclasses.
        """
        raise NotImplementedError()

    def setup(self, **kwargs):
        """
        Optionally, perform one-time setup before calling new_reader().
        Subclass should make sure this function is only called once.
        """
        raise NotImplementedError()

    def new_reader(self, **kwargs):
        """Create a new reader. Must be provided by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class PipedReaderBuilder(ReaderBuilder):
    """ReaderBuilder that modifies underlying builder by calling `piper`
    function on each new reader produced, and return the result of
    the function. This way, it is possible to append data processing
    pipelines that will be replicated for each reader that gets created.

    E.g.:

    PipedReaderBuilder(
        ReaderBuilder(...),
        lambda reader: pipe(reader, processor=my_proc))
    """

    def __init__(self, builder, piper):
        self._builder = builder
        self._piper = piper

    def schema(self):
        """Delegate to the wrapped builder's schema."""
        return self._builder.schema()

    def setup(self, **kwargs):
        """Delegate setup, with all keyword arguments, to the wrapped builder."""
        return self._builder.setup(**kwargs)

    def new_reader(self, **kwargs):
        """Create a reader from the wrapped builder and run it through `piper`.

        If `piper` returns something that is not a Reader, its `reader()`
        method is called to obtain one.
        """
        # Passing everything down since you could wrap a PipedReaderBuilder in
        # another PipedReaderBuilder
        piper = self._piper
        piped = piper(reader=self._builder.new_reader(**kwargs), **kwargs)
        if isinstance(piped, Reader):
            return piped
        return piped.reader()
|
||||
|
||||
|
||||
class Pipe:
    """Base class that tracks a schema and how many readers and writers
    have been registered; `reader` and `writer` are left to subclasses.
    """

    def __init__(self, schema=None, obj_key=None):
        self._schema = schema
        self._obj_key = obj_key
        self._num_readers = 0
        self._num_writers = 0

    def schema(self):
        return self._schema

    def setup(self, global_init_net):
        """No-op in the base class."""
        pass

    def reader(self):
        raise NotImplementedError()

    def writer(self):
        raise NotImplementedError()

    def num_readers(self):
        return self._num_readers

    def num_writers(self):
        return self._num_writers

    def _new_writer(self, writer_schema, writer_init_net):
        """Register a writer: adopt its schema if none is set yet, bump the
        writer count and, when an `obj_key` was given, attach this pipe to
        `writer_init_net` under that key.
        """
        if self._schema is None and writer_schema is not None:
            self._schema = writer_schema
        self._num_writers += 1
        if self._obj_key is None:
            return
        writer_init_net.add_attribute(self._obj_key, self)

    def _new_reader(self, reader_init_net):
        """Register a reader: bump the reader count and, when an `obj_key`
        was given, attach this pipe to `reader_init_net` under that key.
        """
        self._num_readers += 1
        if self._obj_key is None:
            return
        reader_init_net.add_attribute(self._obj_key, self)
|
||||
|
||||
|
||||
class CounterReader(Reader):
    """ Reader that produces increasing integers. """

    def __init__(self):
        super().__init__(schema=Struct(('iter', np.int64)))
        self.counter = None
        self.should_stop = None

    def setup_ex(self, global_init_net, global_finish_net):
        """Create the counter and a constant-False stop blob in
        `global_init_net`; does nothing if the counter already exists.
        """
        if self.counter is not None:
            return
        self.counter = global_init_net.CreateCounter([], init_count=0)
        self.should_stop = global_init_net.ConstantFill(
            [], shape=[], dtype=core.DataType.BOOL, value=False)

    def read_ex(self, local_init_net, local_finish_net):
        """Return a net that counts up the shared counter, together with
        the stop blob and the single counted value as the fields.
        """
        net = core.Net('limited_reader_counter')
        current = net.CountUp([self.counter], 1)
        return [net], self.should_stop, [current]
|
||||
|
||||
|
||||
class ReaderWithLimitBase(Reader):
    """Abstract Reader constrained by certain conditions.

    Base class for Reader classes which check for certain conditions to stop
    further processing (e.g. max number of iterations or time limit).
    Also produces a boolean blob (data_finished) that can be used to see if
    the reader exhausted all input data (true) or stopped for another reason
    (false).
    """

    def __init__(self, reader):
        super().__init__(schema=reader._schema)
        self.reader = reader
        self.net = core.Net('reader_with_limit')
        finished_name = self.net.NextName('data_finished')
        self._data_finished = self.net.AddExternalInput(finished_name)
        self.should_stop = None

    def setup_ex(self, global_init_net, global_finish_net):
        """Initialize `data_finished` to False, then set up the wrapped
        reader and the subclass-specific limiter, in that order.
        """
        global_init_net.ConstantFill(
            [], [self._data_finished],
            shape=[], value=False, dtype=core.DataType.BOOL)
        self.reader.setup_ex(global_init_net, global_finish_net)
        self.setup_limiter(global_init_net, global_finish_net)

    def read_ex(self, local_init_net, local_finish_net):
        """Reads from an underlying Reader class, but may stop due to additional
        constraints.

        Build and return network(s) to read data from a Reader with
        additional constraints, depending on which derived class is used.
        Derived classes implement setup_limiter and check_limiter_condition
        which determine the nature of the constraint imposed on the reader,
        e.g. iteration limits or time limit.

        Args:
            local_init_net: A net invoked at task instance init time (Once per
                            parallel thread).
            local_finish_net: A net invoked at task instance cleanup time (Once
                              per parallel thread).

        Returns:
            A tuple (nets, should_stop, fields) where nets is the limiter
            check net, followed by the wrapped reader's nets, followed by
            the net that propagates the wrapped reader's stop signal.
        """

        # Check if limiting constraint is met.
        limit_check_net = core.Net('limited_reader_condition')
        should_stop = self.check_limiter_condition(limit_check_net)

        # Call original reader.
        inner_nets, inner_finished, fields = self.reader.read_ex(
            local_init_net, local_finish_net)
        self._set_schema(self.reader._schema)

        # Check if original reader is done.
        post_read_net = core.Net('limited_reader_post')
        # Copy to the same blob as the counter output to trigger reader
        # stopping - this is ok because execution will check should_stop_blob
        # after every single operation, so it has already been checked on this
        # iteration by this point.
        post_read_net.Copy(inner_finished, should_stop)
        # Update externally-accessible flag indicating if reader is done
        post_read_net.Or([self._data_finished, inner_finished],
                         [self._data_finished])

        all_nets = [limit_check_net] + inner_nets + [post_read_net]
        return all_nets, should_stop, fields

    def setup_limiter(self, global_init_net, global_finish_net):
        """Configure task level init/cleanup nets required to implement limit
        condition. Must be implemented by subclass.

        Args:
            global_init_net: A net invoked at task init time.
            global_finish_net: A net invoked at task cleanup time.
        """
        raise NotImplementedError("Subclass must implement `setup_limiter`")

    def check_limiter_condition(self, stop_condition_net):
        """Configure a net that is invoked between reading batches to see if
        limit condition is met. Must be implemented by subclass.

        Args:
            stop_condition_net: A net invoked to evaluate an early termination
                                condition.
        """
        raise NotImplementedError("Subclass must implement `check_limiter_condition")

    def data_finished(self):
        """
        Return a blob that can be checked after the end of the reading task,
        which will contain a boolean scalar indicating whether the underlying
        reader has been exhausted (True) or whether we stopped because the
        limit was reached (False).
        """
        return self._data_finished
|
||||
|
||||
|
||||
class ReaderWithLimit(ReaderWithLimitBase):
    """Reader that stops after `num_iter` batches.

    If `num_iter` is None, reverts to an unconstrained reader that
    exports a boolean blob indicating that the reader has exhausted
    the data stream. A `num_iter` of 0 is still a limit: no batch is read.
    """
    def __init__(self, reader, num_iter=1):
        """Class initializer.

        Args:
            reader: The underlying reader object doing the actual read.
            num_iter: Number of batches to read. If `None`,
                the class reverts to a normal reader except that it also
                produces a data_finished blob as a side effect to indicate
                whether the input stream is exhausted.
        """
        super().__init__(reader)
        self.num_iter = num_iter
        # The counter blob only exists when a limit was requested.
        self.counter = (
            None if num_iter is None
            else self.net.AddExternalInput(self.net.NextName('counter'))
        )

    def setup_limiter(self, global_init_net, global_finish_net):
        """Create the countdown counter, starting at `num_iter`, if limited."""
        if not self.counter:
            return
        global_init_net.CreateCounter(
            [], [self.counter], init_count=int(self.num_iter))

    def check_limiter_condition(self, stop_condition_net):
        """Return the blob telling whether the batch budget is used up.

        Without a counter, the blob is a constant False.
        """
        if not self.counter:
            return stop_condition_net.ConstantFill(
                [], 1,
                shape=[], value=False, dtype=core.DataType.BOOL)
        return stop_condition_net.CountDown([self.counter], 1)
|
||||
|
||||
|
||||
def CountUntil(num_iter):
    """Return a CounterReader wrapped in a ReaderWithLimit of `num_iter`."""
    counting_reader = CounterReader()
    return ReaderWithLimit(counting_reader, num_iter)
|
||||
|
||||
|
||||
class ReaderWithTimeLimit(ReaderWithLimitBase):
    """Reader that stops after `duration` seconds.

    If `duration` <= 0 or is None, reverts to an unconstrained reader that
    exports a boolean blob indicating that the reader has exhausted
    the data stream.
    """
    def __init__(self, reader, duration=0):
        """Class initializer.

        Args:
            reader: The underlying reader object doing the actual read.
            duration: Number of seconds to read. If un-specified, None, or <= 0,
                the class reverts to a normal reader except that it also
                produces a data_finished blob as a side effect to indicate
                whether the input stream is exhausted.
        """
        super().__init__(reader)

        self.timer = None
        self.duration = duration
        self.duration_ns_blob = None

    def _has_time_limit(self):
        """Return True when `duration` requests an actual time limit."""
        return self.duration is not None and self.duration > 0

    def setup_limiter(self, global_init_net, global_finish_net):
        """Start the timer and store the limit in nanoseconds, if limited.

        Args:
            global_init_net: A net invoked at task init time; receives the
                timer start and the duration constant.
            global_finish_net: A net invoked at task cleanup time; receives
                the timer end.
        """
        if self._has_time_limit():
            duration_ns = int(self.duration * (10**9))

            self.timer = global_init_net.TimerBegin(
                [], counter_name='epoch_timer')
            start_time = global_init_net.TimerGet(self.timer)
            self.duration_ns_blob = global_init_net.ConstantFill(
                [start_time], value=duration_ns)

            global_finish_net.TimerEnd([self.timer], [])

    def check_limiter_condition(self, stop_condition_net):
        """Return the blob telling whether the time budget is used up.

        Uses the same predicate as `setup_limiter`, so a negative duration
        yields a constant False blob instead of querying a timer that was
        never created.

        Args:
            stop_condition_net: A net invoked to evaluate the early
                termination condition.
        """
        if self._has_time_limit():
            time_elapsed = stop_condition_net.TimerGet(self.timer)
            # NOTE(review): self.should_stop is not assigned anywhere in this
            # class, so the output blob name is str() of whatever the base
            # class left there - confirm this is intended.
            return stop_condition_net.GE(
                [time_elapsed, self.duration_ns_blob], str(self.should_stop))
        else:
            return stop_condition_net.ConstantFill(
                [], 1, shape=[], value=False, dtype=core.DataType.BOOL
            )
|
||||
|
||||
|
||||
class ReaderWithDelay(Reader):
    """Test reader class that inserts a delay between reading batches."""

    def __init__(self, reader, delay):
        super().__init__(schema=reader._schema)
        self.reader = reader
        self.delay = delay

    def setup_ex(self, global_init_net, global_finish_net):
        """Forward setup to the wrapped reader."""
        self.reader.setup_ex(global_init_net, global_finish_net)

    def read_ex(self, local_init_net, local_finish_net):
        """Build a net that first runs a Python sleep op and then the
        wrapped reader's read ops.
        """
        read_net = core.Net("reader_body")

        def sleep_op(*args, **argd):
            # The delay is looked up when the op runs, not when it is built.
            time.sleep(self.delay)

        add_sleep = read_net.Python(sleep_op)
        add_sleep([], [])
        read_result = self.reader.read(read_net)
        return ([read_net],) + read_result
|
||||
|
||||
|
||||
class CompositeReader(Reader):
    """
    Base class for a reader that wrap multiple readers, e.g., reading from
    multiple sources simultaneously.
    """
    def __init__(self, names, readers):
        """
        Args:
            names: list[str] names of readers; used as schema keys
            readers: list[Reader] Reader instances, must have schema
        """
        assert len(names) == len(readers)
        schema_entries = [
            (name, reader.schema()) for name, reader in zip(names, readers)
        ]
        super().__init__(schema=Struct(*schema_entries))
        self._names = names
        self._readers = readers

    def setup_ex(self, init_net, finish_net):
        """Set up every wrapped reader with the same init/finish nets."""
        for sub_reader in self._readers:
            sub_reader.setup_ex(init_net, finish_net)

    def read_ex(self, local_init_net, local_finish_net):
        """
        Stops when one of the reader finished
        """
        # First, instantiate all the reader nets
        fields = []
        stop_blobs = []
        nets_per_reader = []
        for _name, sub_reader in zip(self._names, self._readers):
            sub_nets, sub_stop, record = sub_reader.read_record_ex(
                local_init_net, local_finish_net)
            stop_blobs.append(sub_stop)
            nets_per_reader.append(sub_nets)
            fields.extend(record.field_blobs())

        # Use the stop blob of the last reader as stop blob of composite reader.
        local_should_stop = stop_blobs[-1]
        read_nets = []
        triples = zip(self._names, nets_per_reader, stop_blobs)
        for name, sub_nets, stop_blob in triples:
            read_nets.extend(sub_nets)
            if stop_blob == local_should_stop:
                # Skip adding stop net because Or([A, A], A) doesn't pass operator
                # schema check
                continue
            # Fold this reader's stop signal into the composite stop blob.
            stop_net = core.Net("{}_stop".format(name))
            stop_net.Or([local_should_stop, stop_blob], local_should_stop)
            read_nets.append(stop_net)

        return read_nets, local_should_stop, fields

    def reset(self, net):
        """Append reset operations of every wrapped reader to `net`."""
        for sub_reader in self._readers:
            sub_reader.reset(net)
|
||||
|
||||
|
||||
class CompositeReaderBuilder(ReaderBuilder):
    """
    A reader builder for CompositeReader
    """
    def __init__(self, names, reader_builders):
        """
        Args:
            names: list[str] names of readers; used as schema keys
            reader_builders: list[ReaderBuilder] ReaderBuilder instances;
                must have schema
        """
        super().__init__()
        self._names = names
        self._reader_builders = reader_builders
        schema_entries = [
            (name, builder.schema())
            for name, builder in zip(names, reader_builders)
        ]
        self._schema = Struct(*schema_entries)

    def schema(self):
        return self._schema

    def setup(self, **kwargs):
        """Run setup on every sub-builder and merge the dicts they return.

        A `limiter` keyword argument, if present, is forwarded only to the
        last sub-builder. Asserts that the returned dicts share neither keys
        nor values.
        """
        merged = {}
        # limiter is stateful; it can only be used once. Since
        # CompositeReader stops when one of the reader stops,
        # this is fine.
        limiter = kwargs.pop("limiter", None)
        for i, reader_builder in enumerate(self._reader_builders):
            is_last = i == len(self._reader_builders) - 1
            if is_last and limiter is not None:
                # The limiter must be applied to the last reader so that the
                # batch counter is incremented only if every reader has data
                kwargs["limiter"] = limiter
            current = reader_builder.setup(**kwargs)
            overlapping_keys = set(merged.keys()) & set(current.keys())
            overlapping_values = set(merged.values()) & set(current.values())
            assert overlapping_keys == set(), "Overlapping keys: {}".format(overlapping_keys)
            assert overlapping_values == set(), "Overlapping values: {}".format(overlapping_values)
            merged.update(current)

        return merged

    def new_reader(self, **kwargs):
        """Create one reader per sub-builder and wrap them in a
        CompositeReader.

        A sub-builder may return a Reader, or any object exposing a
        `reader()` method that produces one.
        """
        readers = []
        for reader_builder in self._reader_builders:
            produced = reader_builder.new_reader(**kwargs)
            if not isinstance(produced, Reader):
                if hasattr(produced, 'reader'):
                    produced = produced.reader()
                else:
                    raise ValueError('reader must be an instance of Reader or Pipe')
            readers.append(produced)

        composite = CompositeReader(self._names, readers)
        assert composite.schema() == self._schema
        return composite
|
||||
445
caffe2/python/dataio_test.py
Normal file
445
caffe2/python/dataio_test.py
Normal file
@ -0,0 +1,445 @@
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
from caffe2.python.dataio import (
|
||||
CompositeReader,
|
||||
CompositeReaderBuilder,
|
||||
ReaderBuilder,
|
||||
ReaderWithDelay,
|
||||
ReaderWithLimit,
|
||||
ReaderWithTimeLimit,
|
||||
)
|
||||
from caffe2.python.dataset import Dataset
|
||||
from caffe2.python.db_file_reader import DBFileReader
|
||||
from caffe2.python.pipeline import pipe
|
||||
from caffe2.python.schema import Struct, NewRecord, FeedRecord
|
||||
from caffe2.python.session import LocalSession
|
||||
from caffe2.python.task import TaskGroup, final_output, WorkspaceType
|
||||
from caffe2.python.test_util import TestCase
|
||||
from caffe2.python.cached_reader import CachedReader
|
||||
from caffe2.python import core, workspace, schema
|
||||
from caffe2.python.net_builder import ops
|
||||
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
import tempfile
|
||||
|
||||
|
||||
def make_source_dataset(ws, size=100, offset=0, name=None):
    """Create a Dataset named `name` (default "src") in workspace `ws`.

    Its single 'label' field holds the integers from `offset` up to, but
    not including, `offset + size`.
    """
    if not name:
        name = "src"
    src_init = core.Net("{}_init".format(name))
    with core.NameScope(name):
        labels = np.array(range(offset, offset + size))
        src_values = Struct(('label', labels))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs, name=name)
    FeedRecord(src_blobs, src_values, ws)
    ws.run(src_init)
    return src_ds
|
||||
|
||||
|
||||
def make_destination_dataset(ws, schema, name=None):
    """Create an empty Dataset with the given `schema` in workspace `ws`.

    The dataset is named `name`, defaulting to 'dst'.
    """
    if not name:
        name = 'dst'
    init_net = core.Net('{}_init'.format(name))
    with core.NameScope(name):
        dst_ds = Dataset(schema, name=name)
        dst_ds.init_empty(init_net)
    ws.run(init_net)
    return dst_ds
|
||||
|
||||
|
||||
class TestReaderBuilder(ReaderBuilder):
    """ReaderBuilder used by the tests; backed by a source dataset that is
    created in `setup`.
    """

    def __init__(self, name, size, offset):
        self._name = name
        self._size = size
        self._offset = offset
        # Filled in by setup().
        self._src_ds = None
        self._schema = schema.Struct(
            ('label', schema.Scalar()),
        )

    def schema(self):
        return self._schema

    def setup(self, ws):
        """Create the source dataset in `ws`; returns an empty dict."""
        self._src_ds = make_source_dataset(
            ws, offset=self._offset, size=self._size, name=self._name)
        return {}

    def new_reader(self, **kwargs):
        """Return the dataset created by `setup` (None before setup)."""
        return self._src_ds
|
||||
|
||||
|
||||
class TestCompositeReader(TestCase):
    """Checks that composite readers copy every source into the matching
    destination field.
    """

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_dses = [
            make_source_dataset(ws, offset=offset, size=size, name=name)
            for (name, offset) in zip(names, offsets)
        ]

        data = [ws.fetch_blob(str(src.field_blobs[0])) for src in src_dses]
        # Sanity check we didn't overwrite anything
        for d, offset in zip(data, offsets):
            npt.assert_array_equal(d, range(offset, offset + size))

        # Make an identically-sized empty destination dataset
        dst_fields = [
            (name, src_ds.content().clone_schema())
            for name, src_ds in zip(names, src_dses)
        ]
        dst_ds_schema = schema.Struct(*dst_fields)
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            sub_readers = [src_ds.reader() for src_ds in src_dses]
            reader = CompositeReader(names, sub_readers)
            pipe(reader, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        # Writers run on several threads, so compare after sorting.
        for i, name in enumerate(names):
            label_blob = str(dst_ds.content()[name].label())
            written_data = sorted(ws.fetch_blob(label_blob))
            npt.assert_array_equal(data[i], written_data, "i: {}".format(i))

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader_builder(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_ds_builders = [
            TestReaderBuilder(offset=offset, size=size, name=name)
            for (name, offset) in zip(names, offsets)
        ]

        # Make an identically-sized empty destination dataset
        dst_fields = [
            (name, src_ds_builder.schema())
            for name, src_ds_builder in zip(names, src_ds_builders)
        ]
        dst_ds_schema = schema.Struct(*dst_fields)
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader_builder = CompositeReaderBuilder(names, src_ds_builders)
            reader_builder.setup(ws=ws)
            composite = reader_builder.new_reader()
            pipe(composite, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        # Writers run on several threads, so compare after sorting.
        for name, offset in zip(names, offsets):
            label_blob = str(dst_ds.content()[name].label())
            written_data = sorted(ws.fetch_blob(label_blob))
            npt.assert_array_equal(range(offset, offset + size), written_data,
                                   "name: {}".format(name))
|
||||
|
||||
|
||||
class TestReaderWithLimit(TestCase):
|
||||
def test_runtime_threads(self):
|
||||
ws = workspace.C.Workspace()
|
||||
session = LocalSession(ws)
|
||||
src_ds = make_source_dataset(ws)
|
||||
totals = [None] * 3
|
||||
|
||||
def proc(rec):
|
||||
# executed once
|
||||
with ops.task_init():
|
||||
counter1 = ops.CreateCounter([], ['global_counter'])
|
||||
counter2 = ops.CreateCounter([], ['global_counter2'])
|
||||
counter3 = ops.CreateCounter([], ['global_counter3'])
|
||||
# executed once per thread
|
||||
with ops.task_instance_init():
|
||||
task_counter = ops.CreateCounter([], ['task_counter'])
|
||||
# executed on each iteration
|
||||
ops.CountUp(counter1)
|
||||
ops.CountUp(task_counter)
|
||||
# executed once per thread
|
||||
with ops.task_instance_exit():
|
||||
with ops.loop(ops.RetrieveCount(task_counter)):
|
||||
ops.CountUp(counter2)
|
||||
ops.CountUp(counter3)
|
||||
# executed once
|
||||
with ops.task_exit():
|
||||
totals[0] = final_output(ops.RetrieveCount(counter1))
|
||||
totals[1] = final_output(ops.RetrieveCount(counter2))
|
||||
totals[2] = final_output(ops.RetrieveCount(counter3))
|
||||
return rec
|
||||
|
||||
# Read full data set from original reader
|
||||
with TaskGroup() as tg:
|
||||
pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
|
||||
session.run(tg)
|
||||
self.assertEqual(totals[0].fetch(), 100)
|
||||
self.assertEqual(totals[1].fetch(), 100)
|
||||
self.assertEqual(totals[2].fetch(), 8)
|
||||
|
||||
# Read with a count-limited reader
|
||||
with TaskGroup() as tg:
|
||||
q1 = pipe(src_ds.reader(), num_runtime_threads=2)
|
||||
q2 = pipe(
|
||||
ReaderWithLimit(q1.reader(), num_iter=25),
|
||||
num_runtime_threads=3)
|
||||
pipe(q2, processor=proc, num_runtime_threads=6)
|
||||
session.run(tg)
|
||||
self.assertEqual(totals[0].fetch(), 25)
|
||||
self.assertEqual(totals[1].fetch(), 25)
|
||||
self.assertEqual(totals[2].fetch(), 6)
|
||||
|
||||
def _test_limit_reader_init_shared(self, size):
|
||||
ws = workspace.C.Workspace()
|
||||
session = LocalSession(ws)
|
||||
|
||||
# Make source dataset
|
||||
src_ds = make_source_dataset(ws, size=size)
|
||||
|
||||
# Make an identically-sized empty destination Dataset
|
||||
dst_ds = make_destination_dataset(ws, src_ds.content().clone_schema())
|
||||
|
||||
return ws, session, src_ds, dst_ds
|
||||
|
||||
def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
|
||||
expected_read_len_threshold,
|
||||
expected_finish, num_threads, read_delay,
|
||||
**limiter_args):
|
||||
ws, session, src_ds, dst_ds = \
|
||||
self._test_limit_reader_init_shared(size)
|
||||
|
||||
# Read without limiter
|
||||
# WorkspaceType.GLOBAL is required because we are fetching
|
||||
# reader.data_finished() after the TaskGroup finishes.
|
||||
with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
|
||||
if read_delay > 0:
|
||||
reader = reader_class(ReaderWithDelay(src_ds.reader(),
|
||||
read_delay),
|
||||
**limiter_args)
|
||||
else:
|
||||
reader = reader_class(src_ds.reader(), **limiter_args)
|
||||
pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
|
||||
session.run(tg)
|
||||
read_len = len(sorted(ws.blobs[str(dst_ds.content().label())].fetch()))
|
||||
|
||||
# Do a fuzzy match (expected_read_len +/- expected_read_len_threshold)
|
||||
# to eliminate flakiness for time-limited tests
|
||||
self.assertGreaterEqual(
|
||||
read_len,
|
||||
expected_read_len - expected_read_len_threshold)
|
||||
self.assertLessEqual(
|
||||
read_len,
|
||||
expected_read_len + expected_read_len_threshold)
|
||||
self.assertEqual(
|
||||
sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
|
||||
list(range(read_len))
|
||||
)
|
||||
self.assertEqual(ws.blobs[str(reader.data_finished())].fetch(),
|
||||
expected_finish)
|
||||
|
||||
def test_count_limit_reader_without_limit(self):
    """With ``num_iter=None`` no count is specified; all records are read."""
    self._test_limit_reader_shared(
        ReaderWithLimit,
        size=100,
        expected_read_len=100,
        expected_read_len_threshold=0,
        expected_finish=True,
        num_threads=8,
        read_delay=0,
        num_iter=None,
    )
|
||||
|
||||
def test_count_limit_reader_with_zero_limit(self):
    """With ``num_iter=0`` the reader should yield zero records."""
    self._test_limit_reader_shared(
        ReaderWithLimit,
        size=100,
        expected_read_len=0,
        expected_read_len_threshold=0,
        expected_finish=False,
        num_threads=8,
        read_delay=0,
        num_iter=0,
    )
|
||||
|
||||
def test_count_limit_reader_with_low_limit(self):
    """Read with an iteration limit smaller than the dataset size."""
    self._test_limit_reader_shared(
        ReaderWithLimit,
        size=100,
        expected_read_len=10,
        expected_read_len_threshold=0,
        expected_finish=False,
        num_threads=8,
        read_delay=0,
        num_iter=10,
    )
|
||||
|
||||
def test_count_limit_reader_with_high_limit(self):
    """Read with an iteration limit larger than the dataset size."""
    self._test_limit_reader_shared(
        ReaderWithLimit,
        size=100,
        expected_read_len=100,
        expected_read_len_threshold=0,
        expected_finish=True,
        num_threads=8,
        read_delay=0,
        num_iter=110,
    )
|
||||
|
||||
def test_time_limit_reader_without_limit(self):
    """With ``duration=0`` no duration is specified; all records are read."""
    self._test_limit_reader_shared(
        ReaderWithTimeLimit,
        size=100,
        expected_read_len=100,
        expected_read_len_threshold=0,
        expected_finish=True,
        num_threads=8,
        read_delay=0.1,
        duration=0,
    )
|
||||
|
||||
def test_time_limit_reader_with_short_limit(self):
    """Read with a time limit too short to consume the whole dataset."""
    dataset_size = 50
    thread_count = 4
    per_read_delay = 0.25
    nominal_duration = 1

    # Upper bound on reads: each thread performs one read per delay period.
    max_reads = int(round(thread_count * nominal_duration / per_read_delay))

    # Because the time limit check happens before the delay + read op,
    # subtract a little bit of time to ensure we don't get in an extra read
    time_limit = nominal_duration - 0.25 * per_read_delay

    # NOTE: `expected_read_len_threshold` was added because this test case
    # has significant execution variation under stress. Under stress, we may
    # read strictly less than the expected # of samples; anywhere from
    # [0,N] where N = max_reads.
    # Hence the expectation is centred on N/2 with a tolerance of N/2.
    half_of_max = max_reads / 2
    self._test_limit_reader_shared(
        ReaderWithTimeLimit,
        size=dataset_size,
        expected_read_len=half_of_max,
        expected_read_len_threshold=half_of_max,
        expected_finish=False,
        num_threads=thread_count,
        read_delay=per_read_delay,
        duration=time_limit,
    )
|
||||
|
||||
def test_time_limit_reader_with_long_limit(self):
    """Read with an ample time limit; every record should be consumed."""
    # NOTE: no `expected_read_len_threshold` is used here because the
    # duration, read_delay, and # threads should be more than sufficient.
    self._test_limit_reader_shared(
        ReaderWithTimeLimit,
        size=50,
        expected_read_len=50,
        expected_read_len_threshold=0,
        expected_finish=True,
        num_threads=4,
        read_delay=0.2,
        duration=10,
    )
|
||||
|
||||
|
||||
class TestDBFileReader(TestCase):
    """Tests for ``CachedReader`` and ``DBFileReader`` using an on-disk DB."""

    def setUp(self):
        # Paths handed out by _make_temp_path(); removed again in tearDown().
        self.temp_paths = []

    def tearDown(self):
        # In case any test method fails, clean up temp paths.
        for path in self.temp_paths:
            self._delete_path(path)
        self.temp_paths = []

    @staticmethod
    def _delete_path(path):
        """Remove ``path`` if it is a file or a directory; otherwise do nothing."""
        if os.path.isfile(path):
            os.remove(path)  # Remove file.
        elif os.path.isdir(path):
            shutil.rmtree(path)  # Remove dir recursively.

    def _make_temp_path(self):
        """Return a path that does not exist yet, for use as ``db_path``.

        The path lives inside a freshly created private directory. Taking the
        name of a temporary file and then deleting the file leaves a window
        in which another process can claim that name; ``mkdtemp`` avoids it.
        The directory is registered for removal in ``tearDown``.
        """
        temp_dir = tempfile.mkdtemp()
        self.temp_paths.append(temp_dir)
        # Reuse the directory's unique basename so each call still yields a
        # distinct final path component.
        return os.path.join(temp_dir, os.path.basename(temp_dir))

    @staticmethod
    def _build_source_reader(ws, size):
        """Create a source dataset of ``size`` in ``ws`` and return its reader."""
        src_ds = make_source_dataset(ws, size)
        return src_ds.reader()

    @staticmethod
    def _read_all_data(ws, reader, session):
        """Pipe ``reader`` into a new destination dataset and fetch its content.

        Returns:
            The fetched value of the destination dataset's label blob.
        """
        dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())

        with TaskGroup() as tg:
            pipe(reader, dst_ds.writer(), num_runtime_threads=8)
        session.run(tg)

        return ws.blobs[str(dst_ds.content().label())].fetch()

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_cached_reader(self):
        """Check that CachedReader serves cached data until the cache is deleted."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Read data for the first time.
        cached_reader1 = CachedReader(
            self._build_source_reader(ws, 100), db_path, loop_over=False,
        )
        build_cache_step = cached_reader1.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader1, session)
        self.assertEqual(sorted(data), list(range(100)))

        # Read data from cache: the source now has 200 records, but the
        # 100 cached records are expected.
        cached_reader2 = CachedReader(
            self._build_source_reader(ws, 200), db_path,
        )
        build_cache_step = cached_reader2.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader2, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)

        # We removed cache so we expect to receive data from original reader.
        cached_reader3 = CachedReader(
            self._build_source_reader(ws, 300), db_path,
        )
        build_cache_step = cached_reader3.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader3, session)
        self.assertEqual(sorted(data), list(range(300)))

        self._delete_path(db_path)

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_db_file_reader(self):
        """Check that DBFileReader reads back a DB file built by CachedReader."""
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Build a cache DB file.
        cached_reader = CachedReader(
            self._build_source_reader(ws, 100),
            db_path=db_path,
            db_type='LevelDB',
        )
        build_cache_step = cached_reader.build_cache_step()
        session.run(build_cache_step)

        # Read data from cache DB file.
        db_file_reader = DBFileReader(
            db_path=db_path,
            db_type='LevelDB',
        )
        data = self._read_all_data(ws, db_file_reader, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user