Hooking backward for QNNPACK (#94432)

Summary: Enable gradient computation through the dynamically quantized linear operator (quantized::linear_dynamic) so that quantized layers can pass an input gradient back to the preceding layers.

Test Plan:
Algorithmic correctness - the input gradient from a dequantized (float) matmul matches the QNNPACK matmul gradient - P616202766

```
dequantized matmul : [1.5463, -0.2917, -2.1735, 0.5689, -1.0795]
QNNPACK matmul : tensor([[ 1.5463, -0.2917, -2.1735,  0.5689, -1.0795]])
```
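
A minimal sketch of this comparison (illustrative names and shapes, not the internal test in P616202766; assumes a QNNPACK-enabled build): the input gradient obtained through quantized::linear_dynamic should match, up to quantization error, the gradient of a float linear that uses the dequantized weight.

```
import torch

torch.backends.quantized.engine = "qnnpack"

# Per-tensor qint8 weight (the backward currently supports per-tensor only).
weight = torch.randn(5, 8)
qweight = torch.quantize_per_tensor(weight, scale=0.1, zero_point=0, dtype=torch.qint8)
packed = torch.ops.quantized.linear_prepack(qweight, None)

x = torch.randn(1, 8, requires_grad=True)
y = torch.ops.quantized.linear_dynamic(x, packed, False)
y.sum().backward()                     # QNNPACK matmul gradient

x_ref = x.detach().clone().requires_grad_(True)
torch.nn.functional.linear(x_ref, qweight.dequantize()).sum().backward()

print("QNNPACK matmul    :", x.grad)
print("dequantized matmul:", x_ref.grad)  # should agree up to quantization error
```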

Differential Revision: D42593235

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94432
Approved by: https://github.com/malfet, https://github.com/kimishpatel
Author: Kwanghoon An
Date: 2023-03-08 10:21:32 +00:00
Committed by: PyTorch MergeBot
Parent: 92edac72aa
Commit: a1d7014c0f
7 changed files with 97 additions and 17 deletions

View File

@@ -15,6 +15,7 @@
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/_empty_per_channel_affine_quantized.h>
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/from_blob.h>
#endif
@@ -65,13 +66,20 @@ std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedLinearWeightsQnnp::
return std::tuple<at::Tensor, c10::optional<at::Tensor>>(orig_weight, bias_);
}
else{
TORCH_WARN(
"Original weight is freed, we are converting pre-packed weight to original weight.");
uint8_t* kernel = w->unpackWeights(w_zero_points.data(), n_elements);
at::Tensor original_tensor = at::from_blob(kernel, weight_sizes, c10::kByte).clone().toType(c10::kQInt8);
original_tensor.sub_(128);
free(kernel);
return std::tuple<at::Tensor, c10::optional<at::Tensor>>(original_tensor, bias_);
float* weight_scales_data = w_scales.data_ptr<float>();
at::Tensor weight_origin;
weight_origin = at::empty(weight_sizes, at::device(c10::kCPU).dtype(at::kChar));
int8_t* weight_ptr_int8 =
reinterpret_cast<int8_t*>(weight_origin.data_ptr<int8_t>());
w->unpackWeights(w_zero_points.data(), weight_ptr_int8);
// See the following for why 128 is subtracted:
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L319
weight_origin.sub_(128);
// As of now, only per-tensor quantization is supported.
// TO-DO: support per-channel quantization as well.
at::Tensor original_quantized_tensor = at::_make_per_tensor_quantized_tensor(weight_origin, weight_scales_data[0], w_zero_points[0]);
TORCH_CHECK(original_quantized_tensor.qscheme() == c10::kPerTensorAffine);
return std::tuple<at::Tensor, c10::optional<at::Tensor>>(original_quantized_tensor, bias_);
}
}
#endif // USE_PYTORCH_QNNPACK
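
A rough sketch of the idea behind the new unpack path, written with public torch functions rather than the C++ internals (values and shapes are illustrative): QNNPACK keeps the packed weight as quint8, i.e. shifted by +128 relative to the original qint8 data, so unpacking shifts back by 128 and re-wraps the int8 values as a per-tensor quantized tensor.

```
import torch

# Original per-tensor qint8 weight, as it looked before prepacking.
w = torch.quantize_per_tensor(torch.randn(3, 4), scale=0.05, zero_point=0, dtype=torch.qint8)

# QNNPACK stores the packed weight as quint8: the int8 values shifted by +128.
stored_u8 = (w.int_repr().to(torch.int16) + 128).to(torch.uint8)

# The unpack path reverses the shift (the sub_(128) call above) and rebuilds a
# per-tensor quantized tensor with _make_per_tensor_quantized_tensor.
recovered_i8 = (stored_u8.to(torch.int16) - 128).to(torch.int8)
recovered = torch._make_per_tensor_quantized_tensor(recovered_i8, w.q_scale(), w.q_zero_point())

assert torch.equal(recovered.int_repr(), w.int_repr())
```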

View File

@@ -50,7 +50,6 @@ struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
w_scales(std::move(w_scales)),
w_zero_points(std::move(w_zps)) {
weight_sizes = this->orig_weight.sizes().vec();
n_elements = std::accumulate(std::begin(weight_sizes), std::end(weight_sizes), 1, std::multiplies<double>());
}
std::unique_ptr<qnnpack::PackBMatrix> w;
@@ -62,7 +61,6 @@ struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
std::vector<uint8_t> w_zero_points;
std::vector<float> requantization_scales;
std::vector<int64_t> weight_sizes;
int n_elements;
at::Tensor apply(
at::Tensor input,

View File

@@ -66,9 +66,9 @@ class PackBMatrix final {
return packed_weights_;
}
uint8_t* unpackWeights(
void unpackWeights(
const uint8_t* kernel_zero_points,
int n_elements
int8_t* kernel
) const;
size_t getInputChannels() const

View File

@@ -32,7 +32,6 @@ PackBMatrix::PackBMatrix(
const uint32_t n_stride = (output_channels + (nr - 1)) & -nr;
const uint32_t k_stride = (input_channels + (kr - 1)) & -kr;
input_channels_ = input_channels;
output_channels_ = output_channels;
packed_weights_ =

View File

@@ -8,9 +8,9 @@
namespace qnnpack {
// For runtime quantization unpacking.
uint8_t* PackBMatrix::unpackWeights(
void PackBMatrix::unpackWeights(
const uint8_t* kernel_zero_points,
int n_elements
int8_t* kernel
) const {
union {
void* const as_void_ptr;
@@ -18,8 +18,6 @@ uint8_t* PackBMatrix::unpackWeights(
int32_t* as_int32_ptr;
} packed = {packed_weights_};
uint8_t* kernel = (uint8_t*)malloc(n_elements * sizeof(uint8_t));;
// C = A * B
// A = M*K
// B = K*N
@@ -67,7 +65,6 @@ uint8_t* PackBMatrix::unpackWeights(
}
}
return kernel;
}
} // namespace qnnpack

View File

@@ -1463,6 +1463,7 @@ def define_buck_targets(
"torch/csrc/jit/mobile/train/random.cpp",
"torch/csrc/jit/mobile/train/sequential.cpp",
":gen_aten_libtorch[autograd/generated/Functions.cpp]",
"torch/csrc/quantized/quantized_backward.cpp",
],
compiler_flags = get_pt_compiler_flags(),
exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],

View File

@@ -0,0 +1,77 @@
#include <ATen/native/quantized/PackedParams.h>
#include <torch/library.h>
#include <torch/torch.h>
namespace {
using namespace torch::autograd;
using namespace at;
// This class is a custom autograd function that enables a quantized tensor to
// pass its input gradient back to the previous layers. It can be used when
// the user is adopting mixed precision for training after quantization.
// From the torch layer we have no direct access to the linear_dynamic
// operator, so it must be reached via the redispatching mechanism.
// TO-DO: currently only per-tensor quantization is supported; per-channel
// support will be added later.
class PackedLinearWeightDynamicBackward
: public Function<PackedLinearWeightDynamicBackward> {
public:
static torch::Tensor forward(
AutogradContext* ctx,
at::Tensor input,
const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
bool reduce_range) {
static auto op =
at::Dispatcher::singleton()
.findSchemaOrThrow("quantized::linear_dynamic", "")
.typed<at::Tensor(
at::Tensor,
c10::intrusive_ptr<
LinearPackedParamsBase,
c10::detail::intrusive_target_default_null_type<
LinearPackedParamsBase>> const&,
bool)>();
auto output = op.redispatch(
DispatchKeySet({DispatchKey::CPU}), input, packed_weight, reduce_range);
// TO-DO: passing packed_weight as saved_data requires more work on adding
// LinearPackedParamsBase to IValue; for now we simply save the unpacked
// weight itself. Reference:
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/ivalue.h
auto unpacked_parameters = packed_weight->unpack();
ctx->saved_data["weight"] = std::get<0>(unpacked_parameters);
return output;
}
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
auto original_weight = ctx->saved_data["weight"].toTensor();
original_weight = at::permute(original_weight, {1, 0});
auto grad_output = grad_outputs[0];
static auto op = at::Dispatcher::singleton()
.findSchemaOrThrow("quantized::linear_prepack", "")
.typed<c10::intrusive_ptr<LinearPackedParamsBase>(
at::Tensor, c10::optional<at::Tensor>)>();
auto prepacked_weight = op.call(original_weight, nullopt);
auto grad_input = prepacked_weight->apply_dynamic(grad_output);
return {grad_input, torch::Tensor(), torch::Tensor()};
}
};
at::Tensor packed_linear_weight_grad(
c10::DispatchKeySet ks,
at::Tensor input,
const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
bool reduce_range) {
return PackedLinearWeightDynamicBackward::apply(
input, packed_weight, reduce_range);
}
} // namespace
namespace at {
namespace native {
namespace {
TORCH_LIBRARY_IMPL(quantized, Autograd, m) {
m.impl(
TORCH_SELECTIVE_NAME("quantized::linear_dynamic"),
TORCH_FN(packed_linear_weight_grad));
}
} // namespace
} // namespace native
} // namespace at
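
Since the forward computes Y = X * W^T, the backward above produces dL/dX = dL/dY * W by prepacking the permuted weight and pushing grad_output through apply_dynamic. A hedged usage sketch of the end result (standard torch.ao.nn.quantized.dynamic modules; assumes a QNNPACK-enabled build and per-tensor quantization): a float layer that feeds a dynamically quantized linear now receives a gradient.

```
import torch

torch.backends.quantized.engine = "qnnpack"

front = torch.nn.Linear(8, 5)                         # float layer being fine-tuned
qlinear = torch.ao.nn.quantized.dynamic.Linear(5, 3)  # dispatches to quantized::linear_dynamic

x = torch.randn(2, 8)
loss = qlinear(front(x)).sum()
loss.backward()  # with this kernel registered, grad_input flows back through the quantized layer

print(front.weight.grad.shape)  # torch.Size([5, 8])
```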