mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Direct FBGEMM integraton into ATen (#13777)
Summary: This PR implements infrastructure for post-processing a model to apply int8 quantization to its `nn.Linear` modules. Highlights of the implementation: 1) Inputs and outputs are `float` (quantized and packed internally), but the weight is quantized and packed ahead of time for efficiency. This implementation performs well in small-batch size GEMM calls. It should not be considered a general-purpose quantized GEMM kernel. 2) Weight packing is dependent on machine architecture (e.g. vector register width), so it is done just-in-time. Concretely, it is done on model load for the weights and it is done during operator execution for the input value. 3) Biases are unquantized 4) We fail loudly if we are attempting to run this on a machine that does not support FBGEMM. This is because we do not want a model's numerics to differ based on which machine it is run on. A model containing these FBGEMM ops *must* be run with FBGEMM The API can be seen in the added test case. Highlights are: 1) `torch.jit.quantized.quantize_linear_modules` walks the module hierarchy of the passed-in Module and replaces all `nn.Linear` modules with a new `QuantizedLinear` module, which encapsulates the behavior described above. 2) `_pack()` and `_unpack()` script methods are present on `QuantizedLinear` modules. These methods should be called before serialization and after deserialization, respectively. This ensures that the weight matrix is properly packed for the running machine's architecture. Note that in the long term, we would like to move toward a more Pickle-style serialization technique, rather than having these explicit methods that mutate member values. This is blocked on being able to assign attributes in a ScriptMethod, among other things. Pull Request resolved: https://github.com/pytorch/pytorch/pull/13777 Differential Revision: D13383276 Pulled By: jamesr66a fbshipit-source-id: 00f29c9f34544add2b90107e3cf55a287802c344
This commit is contained in:
committed by
Facebook Github Bot
parent
614121c1ef
commit
acbd9c49b0
@ -198,6 +198,10 @@ include(ExternalProject)
|
|||||||
# ---[ Dependencies
|
# ---[ Dependencies
|
||||||
include(cmake/Dependencies.cmake)
|
include(cmake/Dependencies.cmake)
|
||||||
|
|
||||||
|
if(USE_FBGEMM)
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_FBGEMM")
|
||||||
|
endif()
|
||||||
|
|
||||||
# ---[ Whitelist file if whitelist is specified
|
# ---[ Whitelist file if whitelist is specified
|
||||||
include(cmake/Whitelist.cmake)
|
include(cmake/Whitelist.cmake)
|
||||||
|
|
||||||
|
308
aten/src/ATen/native/QuantizedLinear.cpp
Normal file
308
aten/src/ATen/native/QuantizedLinear.cpp
Normal file
@ -0,0 +1,308 @@
|
|||||||
|
#include "ATen/ATen.h"
|
||||||
|
#include "ATen/NativeFunctions.h"
|
||||||
|
#include "ATen/WrapDimUtilsMulti.h"
|
||||||
|
|
||||||
|
#ifdef USE_FBGEMM
|
||||||
|
#include "fbgemm/Fbgemm.h"
|
||||||
|
#include "fbgemm/QuantUtils.h"
|
||||||
|
#endif // USE_FBGEMM
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <cctype>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
namespace at {
|
||||||
|
namespace native {
|
||||||
|
|
||||||
|
#ifdef USE_FBGEMM
|
||||||
|
|
||||||
|
Tensor fbgemm_linear_int8_weight(
|
||||||
|
const Tensor& input,
|
||||||
|
const Tensor& weight,
|
||||||
|
const Tensor& packed,
|
||||||
|
const Tensor& col_offsets,
|
||||||
|
Scalar weight_scale,
|
||||||
|
Scalar weight_zero_point,
|
||||||
|
const Tensor& bias) {
|
||||||
|
// We make a strong guarantee that models using these operators will have the
|
||||||
|
// same numerics across different machines. Therefore, we do not provide a
|
||||||
|
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||||
|
AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
|
||||||
|
|
||||||
|
// We call contiguous on `input` and `weight` here because these APIs all
|
||||||
|
// expect row-major tensor buffers.
|
||||||
|
auto* input_ptr = input.contiguous().data<float>();
|
||||||
|
auto* weight_ptr = weight.contiguous().data<int8_t>();
|
||||||
|
|
||||||
|
AT_ASSERT(input.dim() >= 2);
|
||||||
|
int64_t M = 1;
|
||||||
|
for (size_t i = 0; i < input.dim() - 1; ++i) {
|
||||||
|
M *= input.size(i);
|
||||||
|
}
|
||||||
|
int64_t K = input.size(input.dim() - 1);
|
||||||
|
AT_ASSERT(weight.dim() == 2);
|
||||||
|
AT_ASSERT(K == weight.size(1));
|
||||||
|
auto N = weight.size(0);
|
||||||
|
AT_ASSERT(bias.dim() == 1);
|
||||||
|
AT_ASSERT(bias.size(0) == N);
|
||||||
|
AT_ASSERT(weight_scale.isFloatingPoint());
|
||||||
|
AT_ASSERT(weight_zero_point.isIntegral());
|
||||||
|
|
||||||
|
// Calculate statistics for quantization of the input Tensor
|
||||||
|
float x_min, x_max;
|
||||||
|
fbgemm::FindMinMax(
|
||||||
|
/*m=*/input_ptr,
|
||||||
|
/*min=*/&x_min,
|
||||||
|
/*max=*/&x_max,
|
||||||
|
/*len=*/input.numel());
|
||||||
|
|
||||||
|
// Input tensor is quantized as 8-bit unsigned values
|
||||||
|
static constexpr int precision = 8;
|
||||||
|
static constexpr bool is_signed = false;
|
||||||
|
|
||||||
|
// Calculate scale and zero point for quantization of input tensor
|
||||||
|
auto q_params = fbgemm::ChooseQuantizationParams(
|
||||||
|
/*min=*/x_min,
|
||||||
|
/*max=*/x_max,
|
||||||
|
/*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
|
||||||
|
/*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
|
||||||
|
/*preserve_sparsity=*/false);
|
||||||
|
|
||||||
|
q_params.precision = precision;
|
||||||
|
|
||||||
|
// This operation does the following:
|
||||||
|
// 1) Quantizes the input matrix given the statistics we've calculated above
|
||||||
|
// 2) Creates a "row buffer" vector with offset values that must be added
|
||||||
|
// to the integer matrix multiplication operation to ensure correctness
|
||||||
|
// 3) Packs the resulting quantized matrix into vector-register and cache
|
||||||
|
// friendly tiles.
|
||||||
|
//
|
||||||
|
// Note this is not executed eagerly, but rather within the fbgemmPacked call
|
||||||
|
// below.
|
||||||
|
fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
|
||||||
|
/*trans=*/fbgemm::matrix_op_t::NoTranspose,
|
||||||
|
/*nRow=*/M,
|
||||||
|
/*nCol=*/K,
|
||||||
|
/*smat=*/input_ptr,
|
||||||
|
/*ld=*/K,
|
||||||
|
/*pmat=*/nullptr, // packA manages ownership of `pmat`
|
||||||
|
/*scale=*/q_params.scale,
|
||||||
|
/*zero_pt=*/q_params.zero_point);
|
||||||
|
|
||||||
|
// ReQuantizeForFloat requires pointers to the scale and zero point values,
|
||||||
|
// since in the case of rowwise quantization these will be arrays rather than
|
||||||
|
// scalars. But in this case, we're doing whole-tensor quantization so we just
|
||||||
|
// pass a pointer to the scale values (and internally ReQuantizeFor Float
|
||||||
|
// won't index past 0
|
||||||
|
float weight_scale_float = static_cast<float>(weight_scale.to<double>());
|
||||||
|
int32_t weight_zero_point_int32 =
|
||||||
|
static_cast<int32_t>(weight_zero_point.to<int64_t>());
|
||||||
|
|
||||||
|
// This is the end of the pipeline, pass the resulting matrix through
|
||||||
|
fbgemm::DoNothing<float, float> doNothingObj{};
|
||||||
|
|
||||||
|
// After the uint8 * int8 matrix multiplication is performed, this operation
|
||||||
|
// does:
|
||||||
|
// 1) Add in row and column offsets to the rows and columns, respectively
|
||||||
|
// 2) Dequantize the results into floating point
|
||||||
|
// 3) Add in the bias term
|
||||||
|
fbgemm::ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
|
||||||
|
/*nextop=*/doNothingObj,
|
||||||
|
/*Aq_scale=*/q_params.scale,
|
||||||
|
/*Bq_scale=*/&weight_scale_float,
|
||||||
|
/*Aq_zero_point=*/q_params.zero_point,
|
||||||
|
/*Bq_zero_point=*/&weight_zero_point_int32,
|
||||||
|
/*row_offsets=*/packA.getRowOffsetBuffer(),
|
||||||
|
/*col_offsets=*/col_offsets.data<int32_t>(),
|
||||||
|
/*bias=*/bias.contiguous().data<float>(),
|
||||||
|
/*ncol=*/N);
|
||||||
|
|
||||||
|
// Allocate output Tensor and a buffer for fbgemmPacked to use
|
||||||
|
auto output = at::zeros_like(bias).to(at::kFloat).expand({M, N}).contiguous();
|
||||||
|
auto buffer = at::zeros_like(output).to(at::kInt).contiguous();
|
||||||
|
|
||||||
|
// Pull out the PackBMatrix instance from the owning tensor
|
||||||
|
auto* packB = reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(
|
||||||
|
packed.storage().data_ptr().get());
|
||||||
|
|
||||||
|
// Do the GEMM
|
||||||
|
fbgemm::fbgemmPacked(
|
||||||
|
/*packA=*/packA,
|
||||||
|
/*packB=*/*packB,
|
||||||
|
/*C=*/output.data<float>(),
|
||||||
|
/*C_buffer=*/buffer.data<int32_t>(),
|
||||||
|
/*ldc=*/N,
|
||||||
|
/*outProcess=*/outputProcObj,
|
||||||
|
/*thread_id=*/0,
|
||||||
|
/*num_threads=*/1);
|
||||||
|
|
||||||
|
// The resulting matrix here is 2-D, let's view it with the original
|
||||||
|
// left hand dimensions of the input.
|
||||||
|
std::vector<int64_t> out_sizes = input.sizes().vec();
|
||||||
|
out_sizes.back() = N;
|
||||||
|
return output.view(out_sizes);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Calculate the column offsets
|
||||||
|
// Note this includes the sum of the columns as well as the scalar term
|
||||||
|
// B_zero_point * K, whereas the row_offsets created by PackAWithQuantRowOffset
|
||||||
|
// is only the sum of the A rows.
|
||||||
|
void calc_col_offsets_transpose(
|
||||||
|
int K,
|
||||||
|
int N,
|
||||||
|
const int8_t* Bint8,
|
||||||
|
int32_t B_zero_point,
|
||||||
|
int32_t* col_offsets) {
|
||||||
|
for (size_t i = 0; i < N; ++i) {
|
||||||
|
int32_t sum = 0;
|
||||||
|
for (size_t j = 0; j < K; ++j) {
|
||||||
|
sum += Bint8[i * K + j];
|
||||||
|
}
|
||||||
|
col_offsets[i] = sum - B_zero_point * K;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
|
||||||
|
const Tensor& weight) {
|
||||||
|
// We make a strong guarantee that models using these operators will have the
|
||||||
|
// same numerics across different machines. Therefore, we do not provide a
|
||||||
|
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||||
|
AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
|
||||||
|
auto weight_contig = weight.contiguous();
|
||||||
|
|
||||||
|
// Calculate weight statistics
|
||||||
|
float w_min, w_max;
|
||||||
|
fbgemm::FindMinMax(
|
||||||
|
/*m=*/weight_contig.data<float>(),
|
||||||
|
/*min=*/&w_min,
|
||||||
|
/*max=*/&w_max,
|
||||||
|
/*len=*/weight_contig.numel());
|
||||||
|
|
||||||
|
// Choose parameters for quantizing the weight as 8-bit signed integer
|
||||||
|
static constexpr bool is_signed = true;
|
||||||
|
static constexpr int precision = 8;
|
||||||
|
auto q_params = fbgemm::ChooseQuantizationParams(
|
||||||
|
/*min=*/w_min,
|
||||||
|
/*max=*/w_max,
|
||||||
|
/*qmin=*/is_signed ? -(1 << (precision - 1)) : 0,
|
||||||
|
/*qmax=*/is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1,
|
||||||
|
/*preserve_sparsity=*/false);
|
||||||
|
|
||||||
|
q_params.precision = precision;
|
||||||
|
|
||||||
|
auto quantized = at::zeros_like(weight_contig).to(at::kChar).contiguous();
|
||||||
|
fbgemm::Quantize<int8_t>(
|
||||||
|
/*src=*/weight_contig.data<float>(),
|
||||||
|
/*dst=*/quantized.data<int8_t>(),
|
||||||
|
/*len=*/weight_contig.numel(),
|
||||||
|
/*qparams=*/q_params);
|
||||||
|
|
||||||
|
// Calculate column offsets of the weight and store them away in a tensor.
|
||||||
|
// Similarly to quantization, this can be done once and cached.
|
||||||
|
auto col_offsets =
|
||||||
|
at::zeros_like(quantized).sum({1}).to(at::kInt).contiguous();
|
||||||
|
calc_col_offsets_transpose(
|
||||||
|
/*K=*/quantized.size(1),
|
||||||
|
/*N=*/quantized.size(0),
|
||||||
|
/*Bint8=*/quantized.data<int8_t>(),
|
||||||
|
/*B_zero_point=*/q_params.zero_point,
|
||||||
|
/*col_offsets=*/col_offsets.data<int32_t>());
|
||||||
|
|
||||||
|
return std::make_tuple(
|
||||||
|
quantized, col_offsets, q_params.scale, q_params.zero_point);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool fbgemm_is_cpu_supported() {
|
||||||
|
return fbgemm::fbgemmSupportedCPU();
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor fbgemm_pack_quantized_matrix(
|
||||||
|
const Tensor& weight,
|
||||||
|
int64_t K,
|
||||||
|
int64_t N) {
|
||||||
|
// We make a strong guarantee that models using these operators will have the
|
||||||
|
// same numerics across different machines. Therefore, we do not provide a
|
||||||
|
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||||
|
AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
|
||||||
|
auto contiguous_ptr = weight.contiguous().data<int8_t>();
|
||||||
|
auto* ptr = new fbgemm::PackBMatrix<int8_t>(
|
||||||
|
/*trans=*/fbgemm::matrix_op_t::Transpose,
|
||||||
|
/*nRow=*/K,
|
||||||
|
/*nCol=*/N,
|
||||||
|
/*smat=*/contiguous_ptr,
|
||||||
|
/*ld=*/K,
|
||||||
|
/*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
|
||||||
|
/*groups=*/1);
|
||||||
|
|
||||||
|
// We store this instance away in a Tensor and register a deleter function
|
||||||
|
// so that we do not leak memory. On the other side, we pull out the storage's
|
||||||
|
// data_ptr and get the PackBMatrix's pointer.
|
||||||
|
at::DataPtr at_ptr(
|
||||||
|
ptr,
|
||||||
|
ptr,
|
||||||
|
[](void* ptr) {
|
||||||
|
fbgemm::PackBMatrix<int8_t>* typed_ptr =
|
||||||
|
reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(ptr);
|
||||||
|
delete typed_ptr;
|
||||||
|
},
|
||||||
|
at::kCPU);
|
||||||
|
|
||||||
|
auto retval = at::empty(
|
||||||
|
{sizeof(fbgemm::PackBMatrix<int8_t>)}, weight.options().dtype(at::kByte));
|
||||||
|
|
||||||
|
retval.storage().set_data_ptr(std::move(at_ptr));
|
||||||
|
|
||||||
|
return retval;
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // USE_FBGEMM
|
||||||
|
|
||||||
|
Tensor fbgemm_linear_int8_weight(
|
||||||
|
const Tensor& /*input*/,
|
||||||
|
const Tensor& /*weight*/,
|
||||||
|
const Tensor& /*packed*/,
|
||||||
|
const Tensor& /*col_offsets*/,
|
||||||
|
Scalar /*weight_scale*/,
|
||||||
|
Scalar /*weight_zero_point*/,
|
||||||
|
const Tensor& /*bias*/) {
|
||||||
|
// We make a strong guarantee that models using these operators will have the
|
||||||
|
// same numerics across different machines. Therefore, we do not provide a
|
||||||
|
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||||
|
AT_ASSERTM(
|
||||||
|
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
|
||||||
|
const Tensor& /*weight*/) {
|
||||||
|
// We make a strong guarantee that models using these operators will have the
|
||||||
|
// same numerics across different machines. Therefore, we do not provide a
|
||||||
|
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||||
|
AT_ASSERTM(
|
||||||
|
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor fbgemm_pack_quantized_matrix(
|
||||||
|
const Tensor& /*input*/,
|
||||||
|
int64_t /*K*/,
|
||||||
|
int64_t /*N*/) {
|
||||||
|
// We make a strong guarantee that models using these operators will have the
|
||||||
|
// same numerics across different machines. Therefore, we do not provide a
|
||||||
|
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||||
|
AT_ASSERTM(
|
||||||
|
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool fbgemm_is_cpu_supported() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // USE_FBGEMM
|
||||||
|
}
|
||||||
|
} // namespace at
|
@ -956,6 +956,14 @@
|
|||||||
|
|
||||||
- func: linear(Tensor input, Tensor weight, Tensor? bias={}) -> Tensor
|
- func: linear(Tensor input, Tensor weight, Tensor? bias={}) -> Tensor
|
||||||
|
|
||||||
|
- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
|
||||||
|
|
||||||
|
- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, double, int64_t)
|
||||||
|
|
||||||
|
- func: fbgemm_pack_quantized_matrix(Tensor input, int64_t K, int64_t N) -> Tensor
|
||||||
|
|
||||||
|
- func: fbgemm_is_cpu_supported() -> bool
|
||||||
|
|
||||||
- func: linspace(Scalar start, Scalar end, int64_t steps=100, TensorOptions options={}) -> Tensor
|
- func: linspace(Scalar start, Scalar end, int64_t steps=100, TensorOptions options={}) -> Tensor
|
||||||
|
|
||||||
- func: linspace_out(Tensor result, Scalar start, Scalar end, int64_t steps=100) -> Tensor
|
- func: linspace_out(Tensor result, Scalar start, Scalar end, int64_t steps=100) -> Tensor
|
||||||
|
@ -350,6 +350,9 @@ endif()
|
|||||||
if(USE_FBGEMM)
|
if(USE_FBGEMM)
|
||||||
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
|
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
|
||||||
include_directories(SYSTEM "${CAFFE2_THIRD_PARTY_ROOT}")
|
include_directories(SYSTEM "${CAFFE2_THIRD_PARTY_ROOT}")
|
||||||
|
caffe2_update_option(USE_FBGEMM ON)
|
||||||
|
else()
|
||||||
|
caffe2_update_option(USE_FBGEMM OFF)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ import torch
|
|||||||
import torch.jit
|
import torch.jit
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import torch.jit.quantized
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from itertools import product, chain
|
from itertools import product, chain
|
||||||
import torch.jit.frontend
|
import torch.jit.frontend
|
||||||
@ -8006,6 +8007,43 @@ a")
|
|||||||
|
|
||||||
traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
|
traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
|
||||||
|
|
||||||
|
# These tests don't work because UBSAN has a false positive about accessing
|
||||||
|
# out of bounds on a dynamically sized struct internal to asmjit
|
||||||
|
if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
|
||||||
|
def test_int8_quantization_module(self):
|
||||||
|
K1, N1 = 2, 2
|
||||||
|
|
||||||
|
class FooBar(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(FooBar, self).__init__()
|
||||||
|
self.linear1 = torch.nn.Linear(K1, N1).float()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.linear1(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
fb = FooBar()
|
||||||
|
fb.linear1.weight = torch.nn.Parameter(
|
||||||
|
torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=False)
|
||||||
|
fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=False)
|
||||||
|
fb_ref = FooBar()
|
||||||
|
fb_ref.linear1.weight = torch.nn.Parameter(fb.linear1.weight.clone(), requires_grad=False)
|
||||||
|
fb_ref.linear1.bias = torch.nn.Parameter(fb.linear1.bias.clone(), requires_grad=False)
|
||||||
|
torch.jit.quantized.quantize_linear_modules(fb)
|
||||||
|
|
||||||
|
x = (torch.rand(1, K1).float() - 0.5) / 10.0
|
||||||
|
traced = torch.jit.trace(fb, (x,))
|
||||||
|
traced.apply(lambda s: s._pack() if s._has_method('_pack') else None)
|
||||||
|
fb = self.getExportImportCopy(traced)
|
||||||
|
traced.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
|
||||||
|
|
||||||
|
fb.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
|
||||||
|
|
||||||
|
x = torch.tensor([[100, -150]], dtype=torch.float)
|
||||||
|
y = fb(x)
|
||||||
|
y_ref = fb_ref(x)
|
||||||
|
torch.testing.assert_allclose(y, y_ref, rtol=0.0001, atol=1e-3)
|
||||||
|
|
||||||
def checkTracerWarning(self, *args, **kwargs):
|
def checkTracerWarning(self, *args, **kwargs):
|
||||||
with warnings.catch_warnings(record=True) as warns:
|
with warnings.catch_warnings(record=True) as warns:
|
||||||
torch.jit.trace(*args, **kwargs)
|
torch.jit.trace(*args, **kwargs)
|
||||||
@ -9112,7 +9150,7 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
|
|||||||
self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
|
self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _test_snli(self, device, check_export_import=True):
|
def _test_snli(self, device, check_export_import=True, quantized=False):
|
||||||
class Bottle(nn.Module):
|
class Bottle(nn.Module):
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
@ -9199,13 +9237,26 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
|
|||||||
premise = torch.LongTensor(48, 128).random_(0, 100).to(device)
|
premise = torch.LongTensor(48, 128).random_(0, 100).to(device)
|
||||||
hypothesis = torch.LongTensor(24, 128).random_(0, 100).to(device)
|
hypothesis = torch.LongTensor(24, 128).random_(0, 100).to(device)
|
||||||
|
|
||||||
self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
|
if quantized:
|
||||||
inputs_require_grads=False, export_import=check_export_import)
|
snli = SNLIClassifier(Config()).cpu()
|
||||||
|
torch.jit.quantized.quantize_linear_modules(snli)
|
||||||
|
# we don't do export/import checks because we would need to call
|
||||||
|
# _pack/_unpack
|
||||||
|
self.checkTrace(snli, (premise, hypothesis), inputs_require_grads=False,
|
||||||
|
export_import=False)
|
||||||
|
else:
|
||||||
|
self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
|
||||||
|
inputs_require_grads=False, export_import=check_export_import)
|
||||||
|
|
||||||
@skipIfRocm
|
@skipIfRocm
|
||||||
def test_snli(self):
|
def test_snli(self):
|
||||||
self._test_snli(self, device='cpu')
|
self._test_snli(self, device='cpu')
|
||||||
|
|
||||||
|
if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
|
||||||
|
@skipIfRocm
|
||||||
|
def test_snli_quantized(self):
|
||||||
|
self._test_snli(self, device='cpu', quantized=True)
|
||||||
|
|
||||||
@skipIfRocm
|
@skipIfRocm
|
||||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||||
def test_snli_cuda(self):
|
def test_snli_cuda(self):
|
||||||
@ -9308,7 +9359,7 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
|
|||||||
export_import=False)
|
export_import=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _test_vae(self, device, check_export_import=True):
|
def _test_vae(self, device, check_export_import=True, quantized=False):
|
||||||
class VAE(nn.Module):
|
class VAE(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(VAE, self).__init__()
|
super(VAE, self).__init__()
|
||||||
@ -9340,13 +9391,26 @@ class TestEndToEndHybridFrontendModels(JitTestCase):
|
|||||||
z = self.reparameterize(mu, logvar)
|
z = self.reparameterize(mu, logvar)
|
||||||
return self.decode(z), mu, logvar
|
return self.decode(z), mu, logvar
|
||||||
|
|
||||||
# eval() is present because randn_like makes this nondeterministic
|
if quantized:
|
||||||
self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
|
vae = VAE().to(device).eval()
|
||||||
export_import=check_export_import)
|
torch.jit.quantized.quantize_linear_modules(vae)
|
||||||
|
# We don't do export/import checks because we would need to call
|
||||||
|
# _unpack and _pack
|
||||||
|
self.checkTrace(vae, (torch.rand(128, 1, 28, 28, device=device),),
|
||||||
|
export_import=False, allow_unused=True,
|
||||||
|
inputs_require_grads=False)
|
||||||
|
else:
|
||||||
|
# eval() is present because randn_like makes this nondeterministic
|
||||||
|
self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
|
||||||
|
export_import=check_export_import)
|
||||||
|
|
||||||
def test_vae(self):
|
def test_vae(self):
|
||||||
self._test_vae(self, device='cpu')
|
self._test_vae(self, device='cpu')
|
||||||
|
|
||||||
|
if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
|
||||||
|
def test_vae_quantized(self):
|
||||||
|
self._test_vae(self, device='cpu', quantized=True)
|
||||||
|
|
||||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||||
def test_vae_cuda(self):
|
def test_vae_cuda(self):
|
||||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||||
|
@ -123,6 +123,7 @@ ${name}(${py_formal_args})""")
|
|||||||
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
|
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
|
||||||
SUPPORTED_RETURN_TYPES = {
|
SUPPORTED_RETURN_TYPES = {
|
||||||
'Tensor', 'std::tuple<Tensor,Tensor>',
|
'Tensor', 'std::tuple<Tensor,Tensor>',
|
||||||
|
'std::tuple<Tensor,Tensor,double,int64_t>',
|
||||||
'std::tuple<Tensor,Tensor,Tensor>',
|
'std::tuple<Tensor,Tensor,Tensor>',
|
||||||
'std::tuple<Tensor,Tensor,Tensor,Tensor>',
|
'std::tuple<Tensor,Tensor,Tensor,Tensor>',
|
||||||
'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
|
'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
|
||||||
|
@ -95,6 +95,16 @@ inline PyObject* wrap(at::Scalar scalar) {
|
|||||||
return wrap(scalar_to_tensor(scalar));
|
return wrap(scalar_to_tensor(scalar));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, float, int64_t> tensors) {
|
||||||
|
auto r = THPObjectPtr{PyTuple_New(4)};
|
||||||
|
if (!r) throw python_error();
|
||||||
|
PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors))));
|
||||||
|
PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors))));
|
||||||
|
PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors))));
|
||||||
|
PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors))));
|
||||||
|
return r.release();
|
||||||
|
}
|
||||||
|
|
||||||
inline PyObject* wrap(THPDtype *dtype) {
|
inline PyObject* wrap(THPDtype *dtype) {
|
||||||
Py_INCREF(dtype);
|
Py_INCREF(dtype);
|
||||||
return (PyObject*)dtype;
|
return (PyObject*)dtype;
|
||||||
|
54
torch/jit/quantized.py
Normal file
54
torch/jit/quantized.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import torch
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizedLinear(torch.jit.ScriptModule):
|
||||||
|
__constants__ = ['scale', 'zero_point']
|
||||||
|
|
||||||
|
def __init__(self, other):
|
||||||
|
super(QuantizedLinear, self).__init__()
|
||||||
|
self.in_features = other.in_features
|
||||||
|
self.out_features = other.out_features
|
||||||
|
# Quantize weight and discard the original
|
||||||
|
self.weight, self.col_offsets, self.scale, self.zero_point = torch.fbgemm_linear_quantize_weight(
|
||||||
|
other.weight.clone().float())
|
||||||
|
self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
|
||||||
|
self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
|
||||||
|
assert other.bias is not None, 'QuantizedLinear requires a bias'
|
||||||
|
self.bias = torch.nn.Parameter(other.bias.clone().float())
|
||||||
|
|
||||||
|
self.register_buffer(
|
||||||
|
'packed_tensor_ptr',
|
||||||
|
torch.fbgemm_pack_quantized_matrix(self.weight.clone(), self.weight.size(1), self.weight.size(0)))
|
||||||
|
|
||||||
|
@torch.jit.script_method
|
||||||
|
def _unpack(self):
|
||||||
|
self.packed_tensor_ptr.set_(
|
||||||
|
torch.fbgemm_pack_quantized_matrix(
|
||||||
|
self.weight, self.weight.size(1), self.weight.size(0)))
|
||||||
|
|
||||||
|
@torch.jit.script_method
|
||||||
|
def _pack(self):
|
||||||
|
self.packed_tensor_ptr.set_(
|
||||||
|
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
|
||||||
|
|
||||||
|
@torch.jit.script_method
|
||||||
|
def forward(self, input):
|
||||||
|
out = torch.fbgemm_linear_int8_weight(
|
||||||
|
input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets,
|
||||||
|
self.scale, self.zero_point, self.bias)
|
||||||
|
return out.type_as(input)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
repr = 'in_features={in_features}, out_features={out_features}, ' \
|
||||||
|
'scale={scale}, zero_point={zero_point}'.format(**self.__dict__)
|
||||||
|
return repr
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_linear_modules(module):
|
||||||
|
for name, mod in module.named_modules():
|
||||||
|
if mod is module:
|
||||||
|
continue
|
||||||
|
if isinstance(mod, torch.nn.Linear):
|
||||||
|
setattr(module, name, QuantizedLinear(mod))
|
||||||
|
quantize_linear_modules(mod)
|
@ -1 +1,2 @@
|
|||||||
vptr:libtorch.so
|
vptr:libtorch.so
|
||||||
|
bounds:asmjit::Zone::_alloc
|
||||||
|
Reference in New Issue
Block a user