Hooking backward for QNNPACK (#94432)
Summary: Enabling quantized gradients.

Test Plan: Algorithmic correctness - dequantized matmul vs. QNNPACK matmul for the gradient - P616202766

```
dequantized matmul : [1.5463, -0.2917, -2.1735, 0.5689, -1.0795]
QNNPACK matmul     : tensor([[ 1.5463, -0.2917, -2.1735, 0.5689, -1.0795]])
```

Differential Revision: D42593235

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94432
Approved by: https://github.com/malfet, https://github.com/kimishpatel
Committed by: PyTorch MergeBot
Parent: 92edac72aa
Commit: a1d7014c0f
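The test plan compares a plain float matmul on the dequantized weight against the QNNPACK dynamic linear op. Below is a minimal sketch of that style of check, not the referenced P616202766 test; the shapes, scales, and tolerance are illustrative:

```python
import torch

# Assumes a build with QNNPACK available.
torch.backends.quantized.engine = "qnnpack"

x = torch.randn(1, 5)
w = torch.randn(5, 5)

# Per-tensor quantize the weight and pack it for the quantized backend.
qw = torch.quantize_per_tensor(w, scale=0.05, zero_point=0, dtype=torch.qint8)
packed = torch.ops.quantized.linear_prepack(qw, None)

# Dynamic quantized linear vs. float matmul on the dequantized weight.
qnnpack_out = torch.ops.quantized.linear_dynamic(x, packed, False)
reference_out = x @ qw.dequantize().t()  # linear computes x @ W^T

print(qnnpack_out)
print(reference_out)
print(torch.allclose(qnnpack_out, reference_out, atol=1e-1))
```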
```diff
@@ -15,6 +15,7 @@
 #else
 #include <ATen/ops/_empty_affine_quantized.h>
 #include <ATen/ops/_empty_per_channel_affine_quantized.h>
+#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
 #include <ATen/ops/empty.h>
 #include <ATen/ops/from_blob.h>
 #endif
```
```diff
@@ -65,13 +66,20 @@ std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightsQnnp::
     return std::tuple<at::Tensor, c10::optional<at::Tensor>>(orig_weight, bias_);
   }
   else{
-    TORCH_WARN(
-        "Original weight is freed, we are converting pre-packed weight to original weight.");
-    uint8_t* kernel = w->unpackWeights(w_zero_points.data(), n_elements);
-    at::Tensor original_tensor = at::from_blob(kernel, weight_sizes, c10::kByte).clone().toType(c10::kQInt8);
-    original_tensor.sub_(128);
-    free(kernel);
-    return std::tuple<at::Tensor, c10::optional<at::Tensor>>(original_tensor, bias_);
+    float* weight_scales_data = w_scales.data_ptr<float>();
+    at::Tensor weight_origin;
+    weight_origin = at::empty(weight_sizes, at::device(c10::kCPU).dtype(at::kChar));
+    int8_t* weight_ptr_int8 =
+        reinterpret_cast<int8_t*>(weight_origin.data_ptr<int8_t>());
+    w->unpackWeights(w_zero_points.data(), weight_ptr_int8);
+    // See the following for the subtraction of 128:
+    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L319
+    weight_origin.sub_(128);
+    // As of now, only a per-tensor quantizer is supported.
+    // TO-DO: support per-channel as well.
+    at::Tensor original_quantized_tensor = at::_make_per_tensor_quantized_tensor(weight_origin, weight_scales_data[0], w_zero_points[0]);
+    TORCH_CHECK(original_quantized_tensor.qscheme() == c10::kPerTensorAffine);
+    return std::tuple<at::Tensor, c10::optional<at::Tensor>>(original_quantized_tensor, bias_);
   }
 }
 #endif // USE_PYTORCH_QNNPACK
```
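For intuition, the new unpack path can be mirrored in Python: QNNPACK holds the kernel as uint8 offset by +128, so recovering the original qint8 tensor is a subtract-128 plus re-attaching the per-tensor quantizer params. The scale and zero point below are placeholders, not values from the diff:

```python
import torch

# Illustrative packed bytes as QNNPACK would hold them (uint8, offset by 128).
packed_uint8 = torch.tensor([120, 128, 136], dtype=torch.uint8)

# Equivalent of weight_origin.sub_(128): shift back into signed int8 range.
int8_vals = (packed_uint8.to(torch.int16) - 128).to(torch.int8)

# Equivalent of at::_make_per_tensor_quantized_tensor: wrap the raw int8
# storage with a scale/zero_point so it behaves as a quantized tensor again.
qweight = torch._make_per_tensor_quantized_tensor(int8_vals, 0.05, 0)
print(qweight.dequantize())  # tensor([-0.4000,  0.0000,  0.4000])
```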
```diff
@@ -50,7 +50,6 @@ struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
         w_scales(std::move(w_scales)),
         w_zero_points(std::move(w_zps)) {
     weight_sizes = this->orig_weight.sizes().vec();
-    n_elements = std::accumulate(std::begin(weight_sizes), std::end(weight_sizes), 1, std::multiplies<double>());
   }

   std::unique_ptr<qnnpack::PackBMatrix> w;
@@ -62,7 +61,6 @@ struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
   std::vector<uint8_t> w_zero_points;
   std::vector<float> requantization_scales;
   std::vector<int64_t> weight_sizes;
-  int n_elements;

   at::Tensor apply(
       at::Tensor input,
```
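The dropped n_elements member (computed via std::accumulate in the constructor) is just the product of weight_sizes, so any caller that needs a flat element count can recompute it on demand. A trivial Python rendering, with an illustrative shape:

```python
import math

weight_sizes = [64, 128]              # illustrative [out_features, in_features]
n_elements = math.prod(weight_sizes)  # what std::accumulate computed
print(n_elements)                     # 8192
```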
```diff
@@ -66,9 +66,9 @@ class PackBMatrix final {
     return packed_weights_;
   }

-  uint8_t* unpackWeights(
+  void unpackWeights(
       const uint8_t* kernel_zero_points,
-      int n_elements
+      int8_t* kernel
   ) const;

   size_t getInputChannels() const
```
```diff
@@ -32,7 +32,6 @@ PackBMatrix::PackBMatrix(

   const uint32_t n_stride = (output_channels + (nr - 1)) & -nr;
   const uint32_t k_stride = (input_channels + (kr - 1)) & -kr;
-
   input_channels_ = input_channels;
   output_channels_ = output_channels;
   packed_weights_ =
```
```diff
@@ -8,9 +8,9 @@

 namespace qnnpack {
 // For runtime quantization unpacking.
-uint8_t* PackBMatrix::unpackWeights(
+void PackBMatrix::unpackWeights(
     const uint8_t* kernel_zero_points,
-    int n_elements
+    int8_t* kernel
 ) const {
   union {
     void* const as_void_ptr;
@@ -18,8 +18,6 @@ uint8_t* PackBMatrix::unpackWeights(
     int32_t* as_int32_ptr;
   } packed = {packed_weights_};

-  uint8_t* kernel = (uint8_t*)malloc(n_elements * sizeof(uint8_t));;
-
   // C = A * B
   // A = M*K
   // B = K*N
@@ -67,7 +65,6 @@ uint8_t* PackBMatrix::unpackWeights(
     }
   }

-  return kernel;
 }

 } // namespace qnnpack
```
```diff
@@ -1463,6 +1463,7 @@ def define_buck_targets(
         "torch/csrc/jit/mobile/train/random.cpp",
         "torch/csrc/jit/mobile/train/sequential.cpp",
         ":gen_aten_libtorch[autograd/generated/Functions.cpp]",
+        "torch/csrc/quantized/quantized_backward.cpp",
     ],
     compiler_flags = get_pt_compiler_flags(),
     exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],
```
torch/csrc/quantized/quantized_backward.cpp (new file, 77 lines)

```diff
@@ -0,0 +1,77 @@
+#include <ATen/native/quantized/PackedParams.h>
+#include <torch/library.h>
+#include <torch/torch.h>
+
+namespace {
+using namespace torch::autograd;
+using namespace at;
+// This class is a custom gradient function that enables a quantized tensor to
+// pass the input gradient back to the previous layers. It can be used when the
+// user is adopting mixed precision for training after quantization. From the
+// torch layer we have no access to the linear_dynamic operator, which must be
+// reached via the redispatching mechanism. TO-DO: currently we support per-
+// tensor quantization only; per-channel support will be added later.
+class PackedLinearWeightDynamicBackward
+    : public Function<PackedLinearWeightDynamicBackward> {
+ public:
+  static torch::Tensor forward(
+      AutogradContext* ctx,
+      at::Tensor input,
+      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
+      bool reduce_range) {
+    static auto op =
+        at::Dispatcher::singleton()
+            .findSchemaOrThrow("quantized::linear_dynamic", "")
+            .typed<at::Tensor(
+                at::Tensor,
+                c10::intrusive_ptr<
+                    LinearPackedParamsBase,
+                    c10::detail::intrusive_target_default_null_type<
+                        LinearPackedParamsBase>> const&,
+                bool)>();
+    auto output = op.redispatch(
+        DispatchKeySet({DispatchKey::CPU}), input, packed_weight, reduce_range);
+    // TO-DO: passing packed_weight as saved_data requires more work in adding
+    // LinearPackedParamsBase to ivalue. For now, we simply save the weight
+    // itself. Reference:
+    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/ivalue.h
+    auto unpacked_parameters = packed_weight->unpack();
+    ctx->saved_data["weight"] = std::get<0>(unpacked_parameters);
+    return output;
+  }
+
+  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
+    auto original_weight = ctx->saved_data["weight"].toTensor();
+    original_weight = at::permute(original_weight, {1, 0});
+    auto grad_output = grad_outputs[0];
+    static auto op = at::Dispatcher::singleton()
+                         .findSchemaOrThrow("quantized::linear_prepack", "")
+                         .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
+                             at::Tensor, c10::optional<at::Tensor>)>();
+    auto prepacked_weight = op.call(original_weight, nullopt);
+    auto grad_input = prepacked_weight->apply_dynamic(grad_output);
+    return {grad_input, torch::Tensor(), torch::Tensor()};
+  }
+};
+
+at::Tensor packed_linear_weight_grad(
+    c10::DispatchKeySet ks,
+    at::Tensor input,
+    const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
+    bool reduce_range) {
+  return PackedLinearWeightDynamicBackward::apply(
+      input, packed_weight, reduce_range);
+}
+} // namespace
+
+namespace at {
+namespace native {
+namespace {
+TORCH_LIBRARY_IMPL(quantized, Autograd, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("quantized::linear_dynamic"),
+      TORCH_FN(packed_linear_weight_grad));
+}
+} // namespace
+} // namespace native
+} // namespace at
```
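In backward above, the saved weight W (shape [out, in]) is permuted to W^T and re-prepacked, so apply_dynamic(grad_output) computes grad_input = grad_output @ W, matching the float linear gradient. A hedged end-to-end sketch of how the hook is exercised, assuming a build where this Autograd kernel is registered (e.g. the mobile targets patched above):

```python
import torch

torch.backends.quantized.engine = "qnnpack"  # assumes QNNPACK is available

x = torch.randn(2, 4, requires_grad=True)
qw = torch.quantize_per_tensor(torch.randn(3, 4), 0.1, 0, torch.qint8)
packed = torch.ops.quantized.linear_prepack(qw, None)

# Forward redispatches to the CPU kernel; backward uses the saved weight.
out = torch.ops.quantized.linear_dynamic(x, packed, False)
out.sum().backward()

# grad_input = grad_output @ W, so it has the input's shape.
print(x.grad.shape)  # torch.Size([2, 4])
```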