#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>

namespace torch::jit {

/** \brief Backend specific pass to fuse dequantize - op - quantize calls
 * into quantized_op calls.
 *
 * Right now this is a fusion for the fbgemm backend and it only works for
 * quantized conv ops; we'll extend it to more ops and more backends in the
 * future.
 *
 * Currently supported fusions:
 * q(conv2d(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_conv2d(prepack(to_nhwc(a)),
 *                                                          prepack(to_nhwc(w)),
 *                                                          prepack(to_nhwc(b))))
 *
 * q(linear(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_linear(prepack(to_nhwc(a)),
 *                                                          prepack(to_nhwc(w)),
 *                                                          prepack(to_nhwc(b))))
 *
 * \param graph the graph we want to apply fusion to
 * \param quant_type the type of quantization to fuse for (STATIC by default)
 */
TORCH_API void QuantFusion(
    std::shared_ptr<Graph>& graph,
    QuantType quant_type = QuantType::STATIC);
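
// A minimal usage sketch (illustrative only, not part of this header):
// running the fusion pass over a scripted module's forward graph. It
// assumes `m` is a Module that has already been through the earlier
// steps of the JIT quantization flow (observer insertion and
// quant/dequant insertion), so the graph actually contains the
// dequantize - op - quantize patterns this pass matches.
//
//   std::shared_ptr<Graph> graph = m.get_method("forward").graph();
//   QuantFusion(graph, QuantType::STATIC);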

/** \brief Insert prepack and unpack functions in the graph
 *
 * We want to add pack/unpack functions for the quantized weight because later
 * we want to fold the packed weight into an attribute of the module, in order
 * to reduce the cost of packing the weight on the fly in quantized models.
 *
 * Each quantized op has its corresponding prepack/unpack function;
 * right now, we only need to do prepack/unpack for quantized::linear
 * and quantized::conv2d.
 */
TORCH_API void InsertPrepackUnpack(std::shared_ptr<Graph>& graph);

/** \brief Insert pack and unpack functions in all graphs
 * of the module
 *
 * Go through the graphs of all the methods of all child modules
 * and call InsertPrepackUnpack on each graph.
 */
TORCH_API void InsertPrepackUnpack(Module& module);
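
// A minimal usage sketch (illustrative only, not part of this header):
// inserting prepack/unpack calls for every method of a quantized
// scripted module, then running the fusion pass on each graph. `m` is
// assumed to be a Module produced by the earlier JIT quantization steps.
//
//   InsertPrepackUnpack(m);
//   for (auto& method : m.get_methods()) {
//     auto graph = method.graph();
//     QuantFusion(graph, QuantType::STATIC);
//   }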

/** \brief Run the finalization step of quantization and return the
 * finalized module.
 *
 * This applies the quantization fusion to the module's graphs and folds
 * the quantized prepacking ops, so packed weights become module
 * attributes instead of being computed on the fly.
 *
 * \param preserved_attrs attributes of the module to preserve during
 * finalization
 */
TORCH_API script::Module Finalize(
    script::Module& module,
    QuantType quant_type = QuantType::STATIC,
    const std::vector<std::string>& preserved_attrs =
        std::vector<std::string>());
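
// A minimal end-to-end sketch (illustrative only, not part of this
// header) of finalizing a module after the convert step of the JIT
// quantization flow. The "version" attribute name is hypothetical; pass
// whichever attributes your module actually needs to keep.
//
//   std::vector<std::string> preserved = {"version"};  // hypothetical
//   Module finalized = Finalize(m, QuantType::STATIC, preserved);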

/** \brief Fold the prepacking ops for quantized weights into the module,
 * so the packed weights are stored as module attributes instead of being
 * packed on the fly.
 */
TORCH_API void FoldQuantizedPrepackingOps(Module& module);

/** \brief Finalization step for the on-device PTQ flow, operating on the
 * graph of the given method.
 */
TORCH_API Module FinalizeOnDevicePTQ(
    Module& module,
    QuantType quant_type,
    const std::string& method_name);

} // namespace torch::jit