Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

[tensorexpr][nnc] Support quantization (#66676)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66676
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D31676329
Pulled By: IvanKobzarev
fbshipit-source-id: 288b41ff4ed603dfaacb465f296997f14bb23c22

Committed by: Facebook GitHub Bot
Parent: 97f29bda59
Commit: 7fbcf79684
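At a high level, the patch makes quantized tensors visible to NNC: c10::qint8 and c10::quint8 become NNC Dtypes and SimpleIREvaluator values, aten::quantize_per_tensor and aten::dequantize get lowerings (native tensorexpr versions exist but external calls are the default, gated by the NNC_QUANTIZATION_EXPR_* defines), and quantized::add, quantized conv2d/conv2d_relu, and upsample_nearest2d are lowered to new nnc_aten_* external functions that call back into the eager ATen kernels. For orientation, the sketch below spells out the per-tensor affine quantization arithmetic these lowerings rely on; it is an illustration written for this summary, not code from the patch, and the helper names are made up (the in-tree lowering approximates rounding by adding 0.5 before the cast).

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative per-tensor affine quantization (hypothetical helpers):
//   quantize:   q = clamp(round(x / scale) + zero_point, qmin, qmax)
//   dequantize: x ~= (q - zero_point) * scale
inline int8_t quantize_to_qint8(float x, double scale, int64_t zero_point) {
  double q = std::nearbyint(x / scale) + static_cast<double>(zero_point);
  q = std::min(127.0, std::max(-128.0, q)); // saturate to the qint8 range
  return static_cast<int8_t>(q);
}

inline float dequantize_qint8(int8_t q, double scale, int64_t zero_point) {
  return static_cast<float>((static_cast<int64_t>(q) - zero_point) * scale);
}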
@@ -393,6 +393,7 @@ namespace c10 {
  _(aten, hardswish_) \
  _(aten, hardsigmoid_) \
  _(aten, hardtanh_) \
  _(aten, quantize_per_tensor) \
  _(aten, dequantize) \

  FORALL_ATEN_BASE_SYMBOLS(_) \
  _(onnx, Add) \

@@ -309,4 +309,9 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
}

} // namespace

Tensor quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point){
  return qadd<false>(qa, qb, scale, zero_point);
}

}} // namespace at::native

aten/src/ATen/native/quantized/cpu/qadd.h (new file, 8 lines)
@@ -0,0 +1,8 @@
#include <ATen/ATen.h>

namespace at {
namespace native {
TORCH_API Tensor
quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point);
}
} // namespace at
@@ -5,13 +5,20 @@
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/script.h>

C10_DEFINE_string(model, "", "The torch script model to optimize.");
C10_DEFINE_string(model_name, "", "The name of the model.");
C10_DEFINE_string(model_version, "", "The version of the model.");

@@ -166,7 +173,25 @@ int main(int argc, char** argv) {
  m.eval();
  auto frozen_m = torch::jit::freeze_module(m.clone());
  auto graph = frozen_m.get_method(FLAGS_method_name).graph();
  auto input_shapes = parseInputShapes();
  std::vector<c10::optional<at::Tensor>> example_inputs;
  example_inputs.reserve(input_shapes.size());
  for (const auto& input_shape : input_shapes) {
    example_inputs.emplace_back(at::rand(input_shape));
  }

  torch::jit::RemoveTensorMutation(graph);
  torch::jit::EliminateDeadCode(graph->block());
  graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);

  torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
  torch::jit::OptimizeFrozenGraph(graph, true);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);

  auto compile_spec = createCompileSpec();
  auto any_dict_ty =
@@ -15,6 +15,7 @@ set(TENSOREXPR_TEST_SRCS
  ${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp
  ${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp
  ${TENSOREXPR_TEST_ROOT}/test_ops.cpp
  ${TENSOREXPR_TEST_ROOT}/test_quantization.cpp
  ${TENSOREXPR_TEST_ROOT}/test_reductions.cpp
  ${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp
  ${TENSOREXPR_TEST_ROOT}/test_simplify.cpp

test/cpp/tensorexpr/test_quantization.cpp (new file, 231 lines)
@@ -0,0 +1,231 @@
#include <gtest/gtest.h>

#include <ATen/native/quantized/cpu/conv_packed_params.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <cmath>
#include <sstream>
#include "torch/csrc/jit/tensorexpr/eval.h"
#include "torch/csrc/jit/tensorexpr/ir.h"

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
using namespace torch::indexing;
using namespace torch::jit::tensorexpr;

class Quantization : public ::testing::Test {
 public:
  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
  void SetUp() {
    getTEMustUseLLVMOnCPU() = false;
  }
};

TEST_F(Quantization, QuantDequantInt8) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=12]()
        %3 : int = prim::Constant[value=13]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8);
  auto y_expected = at::dequantize(q);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %3 : int = prim::Constant[value=122]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
  auto y_expected = at::dequantize(q);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

at::Tensor quantized_add(
    at::Tensor x1,
    at::Tensor x2,
    double scale,
    int64_t zero) {
  const auto qadd_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::add", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
  return qadd_op.call(x1, x2, scale, zero);
}

TEST_F(Quantization, QuantAddDequantInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=12]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8);
  auto qa = quantized_add(q1, q2, 0.1f, 13);
  auto y_expected = at::dequantize(qa);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1, x2};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantAddDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
  auto qa = quantized_add(q1, q2, 0.1f, 13);
  auto y_expected = at::dequantize(qa);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1, x2};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %4 : NoneType = prim::Constant()
        %3 : int[] = prim::Constant[value=[4, 4]]()
        %qz : int = prim::Constant[value=13]()
        %qs : float = prim::Constant[value=0.1]()
        %q : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qs, %qz, %2)
        %qu : QUInt8(1, 1, 4, 4) = aten::upsample_nearest2d(%q, %3, %4)
        %6 : Float(1, 1, 4, 4) = aten::dequantize(%qu)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
  auto qu = at::upsample_nearest2d(q, {4, 4});
  auto y_expected = at::dequantize(qu);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "q:\n" << q << std::endl;
    std::cout << "qu:\n" << qu << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

} // namespace jit
} // namespace torch
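Each of the tests above follows the same recipe: parse a TorchScript graph that quantizes its inputs, optionally runs a quantized op, and dequantizes; build a TensorExprKernel over that graph; run the kernel on a stack of IValues; and check the result against the eager ATen computation with at::allclose.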
@@ -1992,6 +1992,7 @@ works_list = [
    'ceil',
    'clamp',
    'clamp.scalar',
    'contiguous',
    'cos',
    'cosh',
    'div.no_rounding_mode',

@@ -323,6 +323,7 @@ core_sources_full_mobile_no_backend_interface = [
    "torch/csrc/jit/tensorexpr/operators/misc.cpp",
    "torch/csrc/jit/tensorexpr/operators/norm.cpp",
    "torch/csrc/jit/tensorexpr/operators/pointwise.cpp",
    "torch/csrc/jit/tensorexpr/operators/quantization.cpp",
    "torch/csrc/jit/tensorexpr/operators/reduction.cpp",
    "torch/csrc/jit/tensorexpr/operators/softmax.cpp",
    "torch/csrc/jit/tensorexpr/reduction.cpp",
@ -462,7 +462,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
std::vector<DstType> dst_values(src_values.size());
|
||||
for (int i = 0; i < src_dtype.lanes(); ++i) {
|
||||
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
|
||||
dst_values[i] = static_cast<DstType>(src_values[i]);
|
||||
dst_values[i] = static_cast<DstType>(underlyingValue(src_values[i]));
|
||||
}
|
||||
return dst_values;
|
||||
}
|
||||
@ -479,6 +479,19 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
break;
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DST_TYPE_CASE);
|
||||
#undef DST_TYPE_CASE
|
||||
#define DST_TYPE_CASE_QUANT(Type, Name, CppType) \
|
||||
case ScalarType::Name: { \
|
||||
std::vector<CppType> vec = castValues<SrcType, CppType>(dst_dtype, v); \
|
||||
std::vector<Type> qvec; \
|
||||
qvec.reserve(vec.size()); \
|
||||
for (CppType u : vec) { \
|
||||
qvec.emplace_back(u); \
|
||||
} \
|
||||
this->value_ = InterpValue(qvec); \
|
||||
} break;
|
||||
DST_TYPE_CASE_QUANT(c10::quint8, QUInt8, uint8_t)
|
||||
DST_TYPE_CASE_QUANT(c10::qint8, QInt8, int8_t)
|
||||
#undef DST_TYPE_CASE_QUANT
|
||||
default:
|
||||
throw unsupported_dtype();
|
||||
}
|
||||
@ -500,6 +513,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
doCastFromSrc<Type>(src_dtype, dst_dtype, value_); \
|
||||
break;
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, SRC_TYPE_CASE);
|
||||
SRC_TYPE_CASE(c10::quint8, QUInt8);
|
||||
SRC_TYPE_CASE(c10::qint8, QInt8);
|
||||
#undef SRC_TYPE_CASE
|
||||
default:
|
||||
throw unsupported_dtype();
|
||||
@ -683,14 +698,18 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
", idx=", \
|
||||
index[i], \
|
||||
", val=", \
|
||||
(int)val[i]); \
|
||||
(int)underlyingValue(val[i])); \
|
||||
} \
|
||||
value_ = InterpValue(val); \
|
||||
} break;
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
TYPE_CASE(c10::quint8, QUInt8);
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
TYPE_CASE(c10::qint8, QInt8);
|
||||
#undef TYPE_CASE
|
||||
default:
|
||||
throw unsupported_dtype();
|
||||
throw unsupported_dtype("scalar type:" + std::to_string(v_sdtype));
|
||||
}
|
||||
}
|
||||
|
||||
@ -725,11 +744,15 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
", idx=", \
|
||||
index[i], \
|
||||
", val=", \
|
||||
(int)value[i]); \
|
||||
(int)underlyingValue(value[i])); \
|
||||
ptr##Name[index[i]] = value[i]; \
|
||||
} \
|
||||
} break;
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
TYPE_CASE(c10::quint8, QUInt8);
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
TYPE_CASE(c10::qint8, QInt8);
|
||||
#undef TYPE_CASE
|
||||
default:
|
||||
throw unsupported_dtype();
|
||||
@ -773,6 +796,12 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
val = value().as<int64_t>();
|
||||
} else if (value().dtype() == kInt) {
|
||||
val = value().intValue();
|
||||
} else if (value().dtype() == kDouble) {
|
||||
auto x = value().as<double>();
|
||||
val = reinterpret_cast<int64_t*>(&x)[0];
|
||||
} else if (value().dtype() == kFloat) {
|
||||
auto x = value().as<float>();
|
||||
val = reinterpret_cast<int64_t*>(&x)[0];
|
||||
} else {
|
||||
throw malformed_input(
|
||||
"extra_args in ExternalCalls must have int64 dtype", v);
|
||||
|
@ -49,11 +49,25 @@ class InterpValue {
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_CTOR);
|
||||
#undef VALUE_CTOR
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
explicit InterpValue(c10::quint8 v) : dtype_(kQUInt8) {
|
||||
QUInt8values.emplace_back(v.val_);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
explicit InterpValue(c10::qint8 v) : dtype_(kQInt8) {
|
||||
QInt8values.emplace_back(v.val_);
|
||||
}
|
||||
|
||||
#define VALUE_VEC_CTOR(Type, Name) \
|
||||
InterpValue(const std::vector<Type>& v) \
|
||||
: dtype_(Dtype(k##Name, v.size())), Name##values(v) {}
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_VEC_CTOR);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
VALUE_VEC_CTOR(c10::quint8, QUInt8)
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
VALUE_VEC_CTOR(c10::qint8, QInt8)
|
||||
#undef VALUE_VEC_CTOR
|
||||
|
||||
template <typename T>
|
||||
@ -73,6 +87,8 @@ class InterpValue {
|
||||
|
||||
#define VALUE_STORAGE(Type, Name) std::vector<Type> Name##values;
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_STORAGE);
|
||||
VALUE_STORAGE(c10::qint8, QInt8);
|
||||
VALUE_STORAGE(c10::quint8, QUInt8);
|
||||
#undef VALUE_STORAGE
|
||||
void* ptr;
|
||||
};
|
||||
@ -86,6 +102,8 @@ class InterpValue {
|
||||
return Name##values[0]; \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH);
|
||||
VALUE_AS_DISPATCH(c10::quint8, QUInt8);
|
||||
VALUE_AS_DISPATCH(c10::qint8, QInt8);
|
||||
#undef VALUE_AS_DISPATCH
|
||||
|
||||
#define VALUE_AS_VEC_DISPATCH(Type, Name) \
|
||||
@ -97,8 +115,25 @@ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH);
|
||||
return Name##values; \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH);
|
||||
VALUE_AS_VEC_DISPATCH(c10::quint8, QUInt8);
|
||||
VALUE_AS_VEC_DISPATCH(c10::qint8, QInt8);
|
||||
#undef VALUE_AS_VEC_DISPATCH
|
||||
|
||||
template <typename Type>
|
||||
auto underlyingValue(Type x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline auto underlyingValue<c10::quint8>(c10::quint8 x) {
|
||||
return x.val_;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline auto underlyingValue<c10::qint8>(c10::qint8 x) {
|
||||
return x.val_;
|
||||
}
|
||||
|
||||
template <typename To, typename From>
|
||||
To raw_bitcast(const From& src) {
|
||||
TORCH_CHECK(sizeof(To) == sizeof(From), "Invalid bitcast invocation");
|
||||
@ -204,6 +239,10 @@ class ExprEval {
|
||||
} break;
|
||||
// NOLINTNEXTLINE(modernize-use-emplace)
|
||||
AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
|
||||
// NOLINTNEXTLINE(modernize-use-emplace)
|
||||
TYPE_CASE(c10::quint8, QUInt8);
|
||||
// NOLINTNEXTLINE(modernize-use-emplace)
|
||||
TYPE_CASE(c10::qint8, QInt8);
|
||||
#undef TYPE_CASE
|
||||
case ScalarType::Bool: {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
@ -229,6 +268,8 @@ class ExprEval {
|
||||
ret_value_ = InterpValue(ret_val_arg[0]); \
|
||||
} break;
|
||||
AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
|
||||
TYPE_CASE(c10::quint8, QUInt8);
|
||||
TYPE_CASE(c10::qint8, QInt8);
|
||||
#undef TYPE_CASE
|
||||
case ScalarType::Bool: {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
|
@ -1,9 +1,14 @@
|
||||
#include <torch/csrc/jit/tensorexpr/external_functions.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/quantized/cpu/conv_packed_params.h>
|
||||
#include <ATen/native/quantized/cpu/qadd.h>
|
||||
#include <ATen/native/xnnpack/OpContext.h>
|
||||
#include <ATen/quantized/Quantizer.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/tensorexpr/exceptions.h>
|
||||
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
|
||||
@ -99,6 +104,200 @@ void nnc_aten_conv2d(
|
||||
memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
|
||||
}
|
||||
|
||||
void nnc_aten_quantized_conv2d(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
int64_t* buf_ranks,
|
||||
int64_t* buf_dims,
|
||||
int8_t* buf_dtypes,
|
||||
int64_t,
|
||||
int64_t* extra_args) {
|
||||
std::vector<at::Tensor> tensors =
|
||||
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
const double x_qscale = ((double*)extra_args)[0];
|
||||
const int64_t x_qzero = extra_args[1];
|
||||
const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
|
||||
at::Tensor qx = at::from_blob_quantized_per_tensor_affine(
|
||||
buf_data[1],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
tensors[1].sizes(),
|
||||
[](void*) {},
|
||||
// NOLINTNEXTLINE
|
||||
x_qscale,
|
||||
x_qzero,
|
||||
at::TensorOptions(toQIntType(x_qdtype)));
|
||||
auto convPackedParams =
|
||||
reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
|
||||
const double out_qscale = ((double*)extra_args)[3];
|
||||
const int64_t out_qzero = extra_args[4];
|
||||
auto r = convPackedParams->apply(qx, out_qscale, out_qzero);
|
||||
r = r.contiguous();
|
||||
memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
|
||||
}
|
||||
|
||||
void nnc_aten_quantized_conv2d_relu(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
int64_t* buf_ranks,
|
||||
int64_t* buf_dims,
|
||||
int8_t* buf_dtypes,
|
||||
int64_t,
|
||||
int64_t* extra_args) {
|
||||
std::vector<at::Tensor> tensors =
|
||||
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
const double x_qscale = ((double*)extra_args)[0];
|
||||
const int64_t x_qzero = extra_args[1];
|
||||
const c10::ScalarType x_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
|
||||
at::Tensor qx = at::from_blob_quantized_per_tensor_affine(
|
||||
buf_data[1],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
tensors[1].sizes(),
|
||||
[](void*) {},
|
||||
// NOLINTNEXTLINE
|
||||
x_qscale,
|
||||
x_qzero,
|
||||
at::TensorOptions(toQIntType(x_qdtype)));
|
||||
auto convPackedParams =
|
||||
reinterpret_cast<ConvPackedParamsBase<2>*>(buf_data[2]);
|
||||
const double out_qscale = ((double*)extra_args)[3];
|
||||
const int64_t out_qzero = extra_args[4];
|
||||
auto r = convPackedParams->apply_relu(qx, out_qscale, out_qzero);
|
||||
r = r.contiguous();
|
||||
memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
|
||||
}
|
||||
|
||||
void nnc_aten_quantized_add(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
int64_t* buf_ranks,
|
||||
int64_t* buf_dims,
|
||||
int8_t* buf_dtypes,
|
||||
int64_t,
|
||||
int64_t* extra_args) {
|
||||
std::vector<at::Tensor> tensors =
|
||||
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
|
||||
const double a_qscale = ((double*)extra_args)[0];
|
||||
const int64_t a_qzero = extra_args[1];
|
||||
const c10::ScalarType a_qdtype = static_cast<c10::ScalarType>(extra_args[2]);
|
||||
at::Tensor qa = at::from_blob_quantized_per_tensor_affine(
|
||||
buf_data[1],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
tensors[1].sizes(),
|
||||
[](void*) {},
|
||||
// NOLINTNEXTLINE
|
||||
a_qscale,
|
||||
a_qzero,
|
||||
at::TensorOptions(toQIntType(a_qdtype)));
|
||||
const double b_qscale = ((double*)extra_args)[3];
|
||||
const int64_t b_qzero = extra_args[4];
|
||||
const c10::ScalarType b_qdtype = static_cast<c10::ScalarType>(extra_args[5]);
|
||||
at::Tensor qb = at::from_blob_quantized_per_tensor_affine(
|
||||
buf_data[2],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
tensors[2].sizes(),
|
||||
[](void*) {},
|
||||
// NOLINTNEXTLINE
|
||||
b_qscale,
|
||||
b_qzero,
|
||||
at::TensorOptions(toQIntType(b_qdtype)));
|
||||
const double out_qscale = ((double*)extra_args)[6];
|
||||
const int64_t out_qzero = extra_args[7];
|
||||
auto r = at::native::quantized_add(qa, qb, out_qscale, out_qzero);
|
||||
r = r.contiguous();
|
||||
memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
|
||||
}
|
||||
|
||||
void nnc_aten_upsample_nearest2d(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
int64_t* buf_ranks,
|
||||
int64_t* buf_dims,
|
||||
int8_t* buf_dtypes,
|
||||
int64_t,
|
||||
int64_t* extra_args) {
|
||||
std::vector<at::Tensor> tensors =
|
||||
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
at::Tensor x = tensors[0];
|
||||
const double x_qscale = ((double*)extra_args)[0];
|
||||
const int64_t x_qzero = extra_args[1];
|
||||
const int64_t x_qdtype = extra_args[2];
|
||||
const auto is_quantized = x_qdtype != -1;
|
||||
if (is_quantized) {
|
||||
x = at::from_blob_quantized_per_tensor_affine(
|
||||
buf_data[1],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
tensors[1].sizes(),
|
||||
[](void*) {},
|
||||
// NOLINTNEXTLINE
|
||||
x_qscale,
|
||||
x_qzero,
|
||||
at::TensorOptions(toQIntType(static_cast<c10::ScalarType>(x_qdtype))));
|
||||
}
|
||||
|
||||
int64_t output_size_h = extra_args[3];
|
||||
int64_t output_size_w = extra_args[4];
|
||||
double scale_factor_h = ((double*)extra_args)[5];
|
||||
double scale_factor_w = ((double*)extra_args)[6];
|
||||
|
||||
auto r = at::upsample_nearest2d(
|
||||
x,
|
||||
(output_size_h != -1)
|
||||
? c10::optional<at::IntArrayRef>({output_size_h, output_size_w})
|
||||
: c10::nullopt,
|
||||
(scale_factor_h != -1.f) ? c10::optional<at::ArrayRef<double>>(
|
||||
{scale_factor_h, scale_factor_w})
|
||||
: c10::nullopt);
|
||||
r = r.contiguous();
|
||||
memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
|
||||
}
|
||||
|
||||
void nnc_aten_quantize_per_tensor(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
int64_t* buf_ranks,
|
||||
int64_t* buf_dims,
|
||||
int8_t* buf_dtypes,
|
||||
int64_t,
|
||||
int64_t* extra_args) {
|
||||
std::vector<at::Tensor> tensors =
|
||||
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
at::Tensor x = tensors[1];
|
||||
const double qscale = ((double*)extra_args)[0];
|
||||
const int64_t qzero = extra_args[1];
|
||||
const c10::ScalarType qdtype = static_cast<c10::ScalarType>(extra_args[2]);
|
||||
auto r = at::quantize_per_tensor(x, qscale, qzero, qdtype);
|
||||
memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
|
||||
}
|
||||
|
||||
void nnc_aten_dequantize(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
int64_t* buf_ranks,
|
||||
int64_t* buf_dims,
|
||||
int8_t* buf_dtypes,
|
||||
int64_t,
|
||||
int64_t* extra_args) {
|
||||
std::vector<at::Tensor> tensors =
|
||||
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
const double qscale = ((double*)extra_args)[0];
|
||||
const int64_t qzero = extra_args[1];
|
||||
const int64_t qdtype = extra_args[2];
|
||||
at::Tensor qx = at::from_blob_quantized_per_tensor_affine(
|
||||
buf_data[1],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
tensors[1].sizes(),
|
||||
[](void*) {},
|
||||
// NOLINTNEXTLINE
|
||||
qscale,
|
||||
qzero,
|
||||
at::TensorOptions(toQIntType(static_cast<c10::ScalarType>(qdtype))));
|
||||
auto r = at::dequantize(qx);
|
||||
memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
|
||||
}
|
||||
|
||||
void nnc_aten_adaptive_avg_pool2d(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
@ -242,6 +441,24 @@ void nnc_prepacked_conv2d_clamp_run(
|
||||
const static RegisterNNCExternalFunction nnc_conv2d(
|
||||
"nnc_aten_conv2d",
|
||||
nnc_aten_conv2d);
|
||||
const static RegisterNNCExternalFunction nnc_quantized_conv2d(
|
||||
"nnc_aten_quantized_conv2d",
|
||||
nnc_aten_quantized_conv2d);
|
||||
const static RegisterNNCExternalFunction nnc_quantized_conv2d_relu(
|
||||
"nnc_aten_quantized_conv2d_relu",
|
||||
nnc_aten_quantized_conv2d_relu);
|
||||
const static RegisterNNCExternalFunction nnc_quantized_add(
|
||||
"nnc_aten_quantized_add",
|
||||
nnc_aten_quantized_add);
|
||||
const static RegisterNNCExternalFunction nnc_quantize_per_tensor(
|
||||
"nnc_aten_quantize_per_tensor",
|
||||
nnc_aten_quantize_per_tensor);
|
||||
const static RegisterNNCExternalFunction nnc_dequantize(
|
||||
"nnc_aten_dequantize",
|
||||
nnc_aten_dequantize);
|
||||
const static RegisterNNCExternalFunction nnc_upsample_nearest2d(
|
||||
"nnc_aten_upsample_nearest2d",
|
||||
nnc_aten_upsample_nearest2d);
|
||||
const static RegisterNNCExternalFunction nnc_adaptive_avg_pool2d(
|
||||
"nnc_aten_adaptive_avg_pool2d",
|
||||
nnc_aten_adaptive_avg_pool2d);
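One convention worth calling out in the shims and registrations above (an observation, not text from the patch): every nnc_aten_* entry point receives raw output/input buffers plus an int64_t extra_args array, quantization scales are passed through that array as raw bit patterns of doubles, quantized inputs are rebuilt with at::from_blob_quantized_per_tensor_affine, the eager ATen kernel is invoked, and the contiguous result is memcpy'd into the output buffer. A minimal sketch of the scale round trip, with hypothetical helper names:

#include <cstdint>
#include <cstring>

// Hypothetical helpers showing how a double scale can travel through an
// int64_t extra_args slot: the evaluator bit-casts it in, and the shim
// reads it back via ((double*)extra_args)[i].
inline int64_t pack_scale(double scale) {
  int64_t slot = 0;
  std::memcpy(&slot, &scale, sizeof(slot));
  return slot;
}

inline double unpack_scale(int64_t slot) {
  double scale = 0.0;
  std::memcpy(&scale, &slot, sizeof(scale));
  return scale;
}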
@@ -13,7 +13,13 @@
  _(nnc_aten_mm) \
  _(nnc_aten_adaptive_avg_pool2d) \
  _(nnc_aten_mean) \
  _(nnc_aten_addmm)
  _(nnc_aten_addmm) \
  _(nnc_aten_quantized_conv2d) \
  _(nnc_aten_quantized_conv2d_relu) \
  _(nnc_aten_quantized_add) \
  _(nnc_aten_quantize_per_tensor) \
  _(nnc_aten_dequantize) \
  _(nnc_aten_upsample_nearest2d)

#define DECLARE_EXTERNAL_FUNCTION(NAME) \
  TORCH_API void NAME( \
@ -4,6 +4,7 @@
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/TensorGeometry.h>
|
||||
#include <c10/core/ScalarTypeToTypeMeta.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/string_utils.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
@ -341,6 +342,8 @@ ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
|
||||
return ArgNone();
|
||||
} else if (val.isIntList()) {
|
||||
return val.toIntVector();
|
||||
} else if (val.isDoubleList()) {
|
||||
return val.toDoubleVector();
|
||||
} else {
|
||||
throw unsupported_dtype(val.type()->str());
|
||||
}
|
||||
@ -400,6 +403,34 @@ c10::optional<ScalarType> findDtypeForValue(const torch::jit::Value* v) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
bool constZeroDimTensorAsScalarArg(
|
||||
const Value* v,
|
||||
std::vector<ArgValue>& args) {
|
||||
if (v->node()->kind() != prim::Constant || !v->type()->cast<TensorType>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto t = toIValue(v)->toTensor();
|
||||
if (t.sizes().size() != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
c10::ScalarType dtype = c10::typeMetaToScalarType(t.dtype());
|
||||
switch (dtype) {
|
||||
case ScalarType::Float:
|
||||
args.emplace_back(t.item().toFloat());
|
||||
return true;
|
||||
case ScalarType::Long:
|
||||
args.emplace_back(t.item().toLong());
|
||||
return true;
|
||||
default:
|
||||
std::stringstream ss;
|
||||
ss << "Unsupported tensor dtype:" << dtype
|
||||
<< " for converting constant 0-dim Tensor to scalar" << std::endl;
|
||||
throw unsupported_dtype(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
|
||||
auto inputs = v->node()->inputs();
|
||||
auto op = v->node()->kind();
|
||||
@ -420,6 +451,15 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
|
||||
argInputs.emplace_back(n->i(attr::chunks));
|
||||
} else if (op == aten::to) {
|
||||
argInputs.emplace_back(toArg(inputs[0]));
|
||||
} else if (op == aten::quantize_per_tensor) {
|
||||
argInputs.emplace_back(toArg(inputs[0]));
|
||||
if (!constZeroDimTensorAsScalarArg(inputs[1], argInputs)) {
|
||||
argInputs.emplace_back(toArg(inputs[1]));
|
||||
}
|
||||
if (!constZeroDimTensorAsScalarArg(inputs[2], argInputs)) {
|
||||
argInputs.emplace_back(toArg(inputs[2]));
|
||||
}
|
||||
argInputs.emplace_back(toArg(inputs[3]));
|
||||
} else if (op == aten::conv2d) {
|
||||
for (auto inp : inputs) {
|
||||
argInputs.emplace_back(toArg(inp));
|
||||
@ -1020,19 +1060,18 @@ void TensorExprKernel::bindConstant(const torch::jit::Value* v) {
|
||||
return;
|
||||
}
|
||||
auto const_tensor = toIValue(v)->toTensor();
|
||||
|
||||
auto scalar_type = c10::typeMetaToScalarType(const_tensor.options().dtype());
|
||||
const auto& tt = v->type()->expect<TensorType>();
|
||||
auto sizes = *tt->sizes().concrete_sizes();
|
||||
auto sizes = const_tensor.sizes();
|
||||
std::vector<ExprHandle> te_sizes;
|
||||
te_sizes.reserve(sizes.size());
|
||||
for (auto s : sizes) {
|
||||
te_sizes.push_back(s);
|
||||
}
|
||||
|
||||
BufPtr buf = alloc<Buf>(
|
||||
"const_" + sanitizeName(v->debugName()),
|
||||
ExprHandleVectorToExprVector(te_sizes),
|
||||
ToDtype(static_cast<ScalarType>(*tt->scalarType())));
|
||||
ToDtype(scalar_type));
|
||||
|
||||
if (!const_tensor.is_contiguous()) {
|
||||
const_tensor = const_tensor.clone().contiguous();
|
||||
|
@ -58,6 +58,8 @@ inline std::string getArgValueName(const ArgValue& a) {
|
||||
return "bool";
|
||||
} else if (c10::get_if<BufList>(&a)) {
|
||||
return "BufList";
|
||||
} else if (c10::get_if<DoubleList>(&a)) {
|
||||
return "DoubleList";
|
||||
} else if (c10::get_if<IntList>(&a)) {
|
||||
return "IntList";
|
||||
} else if (c10::get_if<ArgNone>(&a)) {
|
||||
|
@@ -495,6 +495,14 @@ llvm::Type* LLVMCodeGenImpl::dtypeToLLVM(Dtype dtype) {

    AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
    case ScalarType::QInt8:
      return CharTy_;
      break;

    case ScalarType::QUInt8:
      return ByteTy_;
      break;

    default:
      throw unsupported_dtype();
  }

@@ -976,9 +984,11 @@ void LLVMCodeGenImpl::visit(CastPtr v) {
  }

  bool destUnsigned = v->dtype().scalar_type() == ScalarType::Byte ||
      v->dtype().scalar_type() == ScalarType::QUInt8 ||
      v->dtype().scalar_type() == ScalarType::Bool;
  bool srcUnsigned =
      v->src_value()->dtype().scalar_type() == ScalarType::Byte ||
      v->src_value()->dtype().scalar_type() == ScalarType::QUInt8 ||
      v->src_value()->dtype().scalar_type() == ScalarType::Bool;

  // Scalar casts
@@ -32,6 +32,9 @@ namespace {
RegisterNNCLoweringsFunction aten_dropout(
    {"aten::dropout(Tensor input, float p, bool train) -> (Tensor)"},
    computeNoop);
RegisterNNCLoweringsFunction aten_contiguous(
    {"aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> (Tensor(a))"},
    computeNoop);

// TODO: convert to schema, add a test
// RegisterNNCLoweringsFunction prepacked_conv2d_clamp_run(
@@ -1480,6 +1483,51 @@ RegisterNNCLoweringsFunction aten_add(
          "aten_add", inputs, outputShape, outputType, add_lambda);
    });

#define NNC_QUANTIZATION_EXPR_QUANT 0
#define NNC_QUANTIZATION_EXPR_DEQUANT 0

RegisterNNCLoweringsFunction aten_quantize_per_tensor(
    {"aten::quantize_per_tensor(Tensor self, float scale, int zero_point, int dtype) -> (Tensor)",
     "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int dtype) -> (Tensor)",
     "aten::quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, int dtype) -> (Tensor[])"},
#if NNC_QUANTIZATION_EXPR_QUANT == 1
    computeQuantizePerTensor
#else
    computeQuantizePerTensorExternalCall
#endif
);

RegisterNNCLoweringsFunction aten_dequantize(
    {"aten::dequantize.self(Tensor self) -> (Tensor)"},
#if NNC_QUANTIZATION_EXPR_DEQUANT == 1
    computeDequantize
#else
    computeDequantizeExternalCall
#endif
);

// TODO: Fix CustomClass register for FunctionSchemeParser in internal build
// RegisterNNCLoweringsFunction quantized_conv2d(
//     {"quantized::conv2d.new(Tensor qx,
//     __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight,
//     float output_scale, int output_zero_point) -> (Tensor)"},
//     computeQuantizedConv2d);

// TODO: Fix CustomClass register for FunctionSchemeParser in internal build
// RegisterNNCLoweringsFunction quantized_conv2d_relu(
//     {"quantized::conv2d_relu.new(Tensor qx,
//     __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight,
//     float output_scale, int output_zero_point) -> (Tensor)"},
//     computeQuantizedConv2dRelu);

RegisterNNCLoweringsFunction quantized_add(
    {"quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> (Tensor qc)"},
    computeQuantizedAdd);

RegisterNNCLoweringsFunction aten_upsample_nearest2d(
    {"aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)"},
    computeUpsampleNearest2d);

} // namespace

} // namespace tensorexpr
@@ -15,6 +15,7 @@ namespace tensorexpr {

using ArgNone = c10::monostate;
using BufList = std::vector<tensorexpr::BufHandle>;
using DoubleList = std::vector<double>;
using IntList = std::vector<int64_t>;
using ArgValue = c10::variant<
    tensorexpr::BufHandle,
@@ -23,6 +24,7 @@ using ArgValue = c10::variant<
    int64_t,
    bool,
    BufList,
    DoubleList,
    IntList,
    ArgNone>;

@@ -5,5 +5,6 @@
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/operators/norm.h>
#include <torch/csrc/jit/tensorexpr/operators/pointwise.h>
#include <torch/csrc/jit/tensorexpr/operators/quantization.h>
#include <torch/csrc/jit/tensorexpr/operators/reduction.h>
#include <torch/csrc/jit/tensorexpr/operators/softmax.h>

torch/csrc/jit/tensorexpr/operators/quantization.cpp (new file, 390 lines)
@@ -0,0 +1,390 @@
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
|
||||
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
|
||||
#include <torch/csrc/jit/tensorexpr/operators/quantization.h>
|
||||
|
||||
using namespace torch::jit::tensorexpr;
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
namespace {
|
||||
std::vector<int64_t> _pair_int(ArgValue v) {
|
||||
if (auto t = c10::get_if<IntList>(&v)) {
|
||||
return {(*t)[0], (*t)[1]};
|
||||
}
|
||||
auto i = c10::get<int64_t>(v);
|
||||
return {i, i};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
double immQScale(const BufHandle& qx) {
|
||||
return to<DoubleImm>(IRSimplifier::simplify(qx.node()->qscale()))->value();
|
||||
}
|
||||
|
||||
int64_t immQZero(const BufHandle& qx) {
|
||||
return to<LongImm>(IRSimplifier::simplify(qx.node()->qzero()))->value();
|
||||
}
|
||||
|
||||
ScalarType immQDType(const BufHandle& qx) {
|
||||
return qx.dtype().scalar_type();
|
||||
}
|
||||
|
||||
bool isQuantized(const BufHandle& qx) {
|
||||
return qx.node()->qscale() && qx.node()->qzero();
|
||||
}
|
||||
|
||||
BufHandle makeQBufHandle(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
Dtype dtype,
|
||||
const ExprPtr qscale,
|
||||
const ExprPtr qzero) {
|
||||
BufHandle ResultBuf(name, dims, dtype);
|
||||
ResultBuf.node()->set_qscale(qscale);
|
||||
ResultBuf.node()->set_qzero(qzero);
|
||||
return ResultBuf;
|
||||
}
|
||||
|
||||
BufHandle makeQBufHandle(
|
||||
const std::string& name,
|
||||
const std::vector<ExprHandle>& dims,
|
||||
Dtype dtype,
|
||||
const double qscale,
|
||||
const int64_t qzero) {
|
||||
return makeQBufHandle(
|
||||
name,
|
||||
dims,
|
||||
dtype,
|
||||
DoubleImm::make(qscale).node(),
|
||||
LongImm::make(qzero).node());
|
||||
}
|
||||
|
||||
Tensor computeQuantizePerTensor(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const c10::optional<ScalarType>&,
|
||||
at::Device) {
|
||||
std::vector<VarPtr> vars;
|
||||
std::vector<ExprHandle> indices;
|
||||
for (const auto& os : outputShape) {
|
||||
auto var = alloc<Var>("", os.node()->dtype());
|
||||
vars.push_back(var);
|
||||
indices.push_back(VarHandle(var));
|
||||
}
|
||||
|
||||
ExprHandle qscale = constant(inputs[1]);
|
||||
ExprHandle qzero = constant(inputs[2]);
|
||||
|
||||
const auto dtype = [](auto qdtype) {
|
||||
if (static_cast<int64_t>(ScalarType::QInt8) == qdtype) {
|
||||
return Dtype(ScalarType::QInt8);
|
||||
} else if (static_cast<int64_t>(ScalarType::QUInt8) == qdtype) {
|
||||
return Dtype(ScalarType::QUInt8);
|
||||
}
|
||||
throw malformed_input("Expected quantized dtype");
|
||||
}(c10::get<int64_t>(inputs[3]));
|
||||
const BufHandle& x = c10::get<BufHandle>(inputs[0]);
|
||||
|
||||
auto x_dtype = x.node()->dtype();
|
||||
auto promoted_qscale = promoteToDtype(qscale, x_dtype.scalar_type());
|
||||
auto promoted_qzero = promoteToDtype(qzero, x_dtype.scalar_type());
|
||||
ExprHandle exprHandle = promoteToDtype(
|
||||
tensorOrConstant(inputs[0], indices) / promoted_qscale + promoted_qzero +
|
||||
FloatImm::make(0.5f),
|
||||
dtype.scalar_type());
|
||||
|
||||
BufPtr buf = alloc<Buf>(
|
||||
"quantize_per_tensor",
|
||||
ExprHandleVectorToExprVector(outputShape),
|
||||
dtype,
|
||||
nullptr,
|
||||
qscale.node(),
|
||||
qzero.node());
|
||||
return Tensor(buf, vars, exprHandle.node());
|
||||
}
|
||||
|
||||
Tensor computeQuantizePerTensorExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
const BufHandle& x = c10::get<BufHandle>(inputs[0]);
|
||||
const auto qscale = c10::get<double>(inputs[1]);
|
||||
const auto qzero = c10::get<int64_t>(inputs[2]);
|
||||
const auto qdtype = c10::get<int64_t>(inputs[3]);
|
||||
|
||||
const auto dtype = [](auto qdtype) {
|
||||
if (static_cast<int64_t>(ScalarType::QInt8) == qdtype) {
|
||||
return Dtype(ScalarType::QInt8);
|
||||
} else if (static_cast<int64_t>(ScalarType::QUInt8) == qdtype) {
|
||||
return Dtype(ScalarType::QUInt8);
|
||||
}
|
||||
throw malformed_input("Expected quantized dtype");
|
||||
}(qdtype);
|
||||
auto ResultBuf =
|
||||
makeQBufHandle("quantize_per_tensor", outputShape, dtype, qscale, qzero);
|
||||
StmtPtr s = ExternalCall::make(
|
||||
ResultBuf, "nnc_aten_quantize_per_tensor", {x}, {qscale, qzero, qdtype});
|
||||
return Tensor(ResultBuf.node(), s);
|
||||
}
|
||||
|
||||
Tensor computeDequantizeExternalCall(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
Dtype dtype = kFloat;
|
||||
if (outputType) {
|
||||
dtype = Dtype(*outputType);
|
||||
}
|
||||
|
||||
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
|
||||
const double qscale = immQScale(qx);
|
||||
const int64_t qzero = immQZero(qx);
|
||||
const int64_t qdtype = (int64_t)immQDType(qx);
|
||||
|
||||
BufHandle ResultBuf("quantize", outputShape, dtype);
|
||||
StmtPtr s = ExternalCall::make(
|
||||
ResultBuf, "nnc_aten_dequantize", {qx}, {qscale, qzero, qdtype});
|
||||
return Tensor(ResultBuf.node(), s);
|
||||
}
|
||||
|
||||
Tensor computeQuantizedConv2dPrepack(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
Dtype dtype = kFloat;
|
||||
if (outputType) {
|
||||
dtype = Dtype(*outputType);
|
||||
}
|
||||
|
||||
BufHandle ResultBuf("quantized_conv2d_prepack", outputShape, dtype);
|
||||
const BufHandle& qw = c10::get<BufHandle>(inputs[0]);
|
||||
const BufHandle& b = c10::get<BufHandle>(inputs[1]);
|
||||
auto strides = _pair_int(inputs[2]);
|
||||
auto padding = _pair_int(inputs[3]);
|
||||
auto dilation = _pair_int(inputs[4]);
|
||||
int groups = c10::get<int64_t>(inputs[5]);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
qw.node()->qscale(),
|
||||
buildErrorMessage(
|
||||
"quantized_conv2d_prepack: Expects quantized weights, qscale is missing"));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
qw.node()->qzero(),
|
||||
buildErrorMessage(
|
||||
"quantized_conv2d_prepack: Expects quantized weights, qzero is missing"));
|
||||
StmtPtr s = ExternalCall::make(
|
||||
ResultBuf,
|
||||
"nnc_aten_quantized_conv2d_prepack",
|
||||
{qw, b},
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
{strides[0],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
strides[1],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
padding[0],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
padding[1],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
dilation[0],
|
||||
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
|
||||
dilation[1],
|
||||
groups,
|
||||
immQScale(qw),
|
||||
immQZero(qw),
|
||||
(int64_t)immQDType(qw)});
|
||||
return Tensor(ResultBuf.node(), s);
|
||||
}
|
||||
|
||||
Tensor computeQuantizedConv2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
at::Device device) {
|
||||
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
|
||||
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
|
||||
const auto out_qscale = c10::get<double>(inputs[2]);
|
||||
const auto out_qzero = c10::get<int64_t>(inputs[3]);
|
||||
// Change to dtype based on outputType when dtype propagation implemented
|
||||
const auto out_qdtype = immQDType(qx);
|
||||
auto ResultBuf = makeQBufHandle(
|
||||
"quantized_conv2d",
|
||||
outputShape,
|
||||
Dtype(out_qdtype),
|
||||
out_qscale,
|
||||
out_qzero);
|
||||
StmtPtr s = ExternalCall::make(
|
||||
ResultBuf,
|
||||
"nnc_aten_quantized_conv2d",
|
||||
{qx, prepacked},
|
||||
{immQScale(qx),
|
||||
immQZero(qx),
|
||||
(int64_t)immQDType(qx),
|
||||
out_qscale,
|
||||
out_qzero});
|
||||
return Tensor(ResultBuf.node(), s);
|
||||
}
|
||||
|
||||
Tensor computeQuantizedConv2dRelu(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
at::Device device) {
|
||||
const BufHandle& qx = c10::get<BufHandle>(inputs[0]);
|
||||
const BufHandle& prepacked = c10::get<BufHandle>(inputs[1]);
|
||||
const auto out_qscale = c10::get<double>(inputs[2]);
|
||||
const auto out_qzero = c10::get<int64_t>(inputs[3]);
|
||||
// Change to dtype based on outputType when dtype propagation implemented
|
||||
const auto out_qdtype = immQDType(qx);
|
||||
auto ResultBuf = makeQBufHandle(
|
||||
"quantized_conv2d_relu",
|
||||
outputShape,
|
||||
Dtype(out_qdtype),
|
||||
out_qscale,
|
||||
out_qzero);
|
||||
StmtPtr s = ExternalCall::make(
|
||||
ResultBuf,
|
||||
"nnc_aten_quantized_conv2d_relu",
|
||||
{qx, prepacked},
|
||||
{immQScale(qx),
|
||||
immQZero(qx),
|
||||
(int64_t)immQDType(qx),
|
||||
out_qscale,
|
||||
out_qzero});
|
||||
return Tensor(ResultBuf.node(), s);
|
||||
}
|
||||
|
||||
Tensor computeQuantizedAdd(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
// NOLINTNEXTLINE
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
// NOLINTNEXTLINE
|
||||
at::Device device) {
|
||||
const BufHandle& qa = c10::get<BufHandle>(inputs[0]);
|
||||
const BufHandle& qb = c10::get<BufHandle>(inputs[1]);
|
||||
const auto out_qscale = c10::get<double>(inputs[2]);
|
||||
const auto out_qzero = c10::get<int64_t>(inputs[3]);
|
||||
// Change to dtype based on outputType when dtype propagation implemented
|
||||
const auto out_qdtype = immQDType(qa);
|
||||
auto ResultBuf = makeQBufHandle(
|
||||
"quantized_add", outputShape, Dtype(out_qdtype), out_qscale, out_qzero);
|
||||
StmtPtr s = ExternalCall::make(
|
||||
ResultBuf,
|
||||
"nnc_aten_quantized_add",
|
||||
{qa, qb},
|
||||
{immQScale(qa),
|
||||
immQZero(qa),
|
||||
(int64_t)immQDType(qa),
|
||||
immQScale(qb),
|
||||
immQZero(qb),
|
||||
(int64_t)immQDType(qb),
|
||||
out_qscale,
|
||||
out_qzero});
|
||||
return Tensor(ResultBuf.node(), s);
|
||||
}
|
||||
|
||||
Tensor computeDequantize(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
Dtype dtype = kFloat;
|
||||
if (outputType) {
|
||||
dtype = Dtype(*outputType);
|
||||
}
|
||||
auto qx = c10::get<BufHandle>(inputs[0]);
|
||||
auto qscale = qx.node()->qscale();
|
||||
auto qzero = qx.node()->qzero();
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
qscale, buildErrorMessage("Missing quantized scale for dequantize"));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
qzero, buildErrorMessage("Missing quantized zero point for dequantize"));
|
||||
std::vector<VarPtr> vars;
|
||||
std::vector<ExprHandle> indices;
|
||||
for (const auto& os : outputShape) {
|
||||
auto var = alloc<Var>("", os.node()->dtype());
|
||||
vars.push_back(var);
|
||||
indices.push_back(VarHandle(var));
|
||||
}
|
||||
auto qx_e_promoted =
|
||||
promoteToDtype(tensorOrConstant(inputs[0], indices), dtype.scalar_type());
|
||||
auto qscale_promoted =
|
||||
promoteToDtype(ExprHandle(qscale), dtype.scalar_type());
|
||||
auto qzero_promoted = promoteToDtype(ExprHandle(qzero), dtype.scalar_type());
|
||||
auto y = promoteToDtype(
|
||||
(qx_e_promoted - qzero_promoted) * qscale_promoted, dtype.scalar_type());
|
||||
|
||||
BufPtr buf = alloc<Buf>(
|
||||
"dequantize", ExprHandleVectorToExprVector(outputShape), dtype);
|
||||
return Tensor(buf, vars, y.node());
|
||||
}
|
||||
|
||||
Tensor computeUpsampleNearest2d(
|
||||
const std::vector<ArgValue>& inputs,
|
||||
const std::vector<ExprHandle>& outputShape,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device) {
|
||||
Dtype dtype = kFloat;
|
||||
if (outputType) {
|
||||
dtype = Dtype(*outputType);
|
||||
}
|
||||
int64_t output_size_h = -1;
|
||||
int64_t output_size_w = -1;
|
||||
if (auto output_sizes = c10::get_if<IntList>(&inputs[1])) {
|
||||
output_size_h = (*output_sizes)[0];
|
||||
output_size_w = (*output_sizes)[1];
|
||||
}
|
||||
|
||||
double scale_factor_h = -1.f;
|
||||
double scale_factor_w = -1.f;
|
||||
if (auto scale_factors = c10::get_if<DoubleList>(&inputs[2])) {
|
||||
scale_factor_h = (*scale_factors)[0];
|
||||
scale_factor_w = (*scale_factors)[1];
|
||||
}
|
||||
const BufHandle& x = c10::get<BufHandle>(inputs[0]);
|
||||
double qx_qscale = -1.f;
|
||||
int64_t qx_qzero = -1l;
|
||||
int64_t qx_qdtype = -1l;
|
||||
if (isQuantized(x)) {
|
||||
qx_qscale = immQScale(x);
|
||||
qx_qzero = immQZero(x);
|
||||
qx_qdtype = (int64_t)immQDType(x);
|
||||
}
|
||||
|
||||
BufHandle ResultBuf = [&]() {
|
||||
if (isQuantized(x)) {
|
||||
return makeQBufHandle(
|
||||
"upsample_nearest2d",
|
||||
outputShape,
|
||||
Dtype(immQDType(x)),
|
||||
qx_qscale,
|
||||
qx_qzero);
|
||||
}
|
||||
return BufHandle("upsample_nearest2d", outputShape, dtype);
|
||||
}();
|
||||
|
||||
StmtPtr s = ExternalCall::make(
|
||||
ResultBuf,
|
||||
"nnc_aten_upsample_nearest2d",
|
||||
{x},
|
||||
{qx_qscale,
|
||||
qx_qzero,
|
||||
qx_qdtype,
|
||||
output_size_h,
|
||||
output_size_w,
|
||||
scale_factor_h,
|
||||
scale_factor_w});
|
||||
return Tensor(ResultBuf.node(), s);
|
||||
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
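A design note on the lowerings in this new file (an observation about the code, not text from the patch): makeQBufHandle stores the scale and zero point directly on the result Buf via set_qscale/set_qzero, so when a consumer such as dequantize or quantized::add is lowered later, immQScale/immQZero/immQDType can recover those parameters as compile-time constants instead of threading them through extra graph inputs.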
torch/csrc/jit/tensorexpr/operators/quantization.h (new file, 74 lines)
@@ -0,0 +1,74 @@
#pragma once

#include <torch/csrc/jit/tensorexpr/kernel.h>

namespace torch {
namespace jit {
namespace tensorexpr {

TORCH_API ExprHandle quantizePerTensorQParamFromArg(ArgValue arg);

TORCH_API double immQScale(const BufHandle& qx);

TORCH_API int64_t immQZero(const BufHandle& qx);

TORCH_API ScalarType immQDType(const BufHandle& qx);

TORCH_API bool isQuantized(const BufHandle& qx);

TORCH_API Tensor computeQuantizePerTensor(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

TORCH_API Tensor computeQuantizePerTensorExternalCall(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

TORCH_API Tensor computeQuantizedConv2dPrepack(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

TORCH_API Tensor computeQuantizedConv2d(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

TORCH_API Tensor computeQuantizedConv2dRelu(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

TORCH_API Tensor computeQuantizedAdd(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

TORCH_API Tensor computeDequantize(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

TORCH_API Tensor computeDequantizeExternalCall(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

TORCH_API Tensor computeUpsampleNearest2d(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);
} // namespace tensorexpr
} // namespace jit
} // namespace torch
@@ -17,6 +17,8 @@ Dtype Dtype::scalar_dtype() const {
#define DTYPE_DEFINE(_1, n) TORCH_API Dtype k##n(ScalarType::n, 1);

AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_DEFINE)
DTYPE_DEFINE(c10::quint8, QUInt8);
DTYPE_DEFINE(c10::qint8, QInt8);

#undef DTYPE_DEFINE

@@ -29,6 +31,8 @@ Dtype ToDtype(ScalarType type) {
  case ScalarType::n: \
    return k##n;
    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
    TYPE_CASE(c10::quint8, QUInt8);
    TYPE_CASE(c10::qint8, QInt8);
#undef TYPE_CASE

    case ScalarType::Undefined:
@@ -57,6 +61,8 @@ int Dtype::byte_size() const {
    break;

    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
    TYPE_CASE(c10::quint8, QUInt8);
    TYPE_CASE(c10::qint8, QInt8);
#undef TYPE_CASE
    default:
      throw std::runtime_error(
@@ -79,6 +85,10 @@ std::string Dtype::ToCppString() const {
      return "half";
    case ScalarType::BFloat16:
      return "__nv_bfloat16";
    case ScalarType::QInt8:
      return "qint8";
    case ScalarType::QUInt8:
      return "quint8";
    default:
      throw unsupported_dtype();
  }

@@ -86,6 +86,8 @@ extern TORCH_API Dtype kHandle;
#define NNC_DTYPE_DECLARATION(ctype, name) extern TORCH_API Dtype k##name;

AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_DTYPE_DECLARATION)
NNC_DTYPE_DECLARATION(c10::quint8, QUInt8);
NNC_DTYPE_DECLARATION(c10::qint8, QInt8);
#undef NNC_DTYPE_DECLARATION

template <typename T>
@@ -97,6 +99,8 @@ TORCH_API Dtype ToDtype();
  return k##name; \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_TODTYPE_DECLARATION)
NNC_TODTYPE_DECLARATION(c10::quint8, QUInt8);
NNC_TODTYPE_DECLARATION(c10::qint8, QInt8);
#undef NNC_TODTYPE_DECLARATION

TORCH_API Dtype ToDtype(ScalarType type);