mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Towards supporting quantized structured kernels (#74560)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74560 This PR add support for quantized tensors with "unknown quantizer", which means that we can use standard APIs like torch.empty to allocate quantized tensors, with the understanding that we will set the quantizer later. This makes meta functions applicable to quantized tensors (they will allocate with unknown quantizer and the kernel will set the quantizer later) and fixes a bug David Dang reported where structured kernels give a weird error message when you call them with quantized inputs. This is not a complete support for quantized structured kernels because I haven't actually tried porting any of the quantized implementations to structured; qadd is probably a good choice to try first as it does its broadcasting implementation using TensorIterator. My goal here is just to show that the error message is better. See also https://github.com/pytorch/pytorch/issues/52680 Signed-off-by: Edward Z. Yang <ezyangfb.com> Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D35317441 Pulled By: dzdang fbshipit-source-id: ffb85b0e06ccbcc2b01052ca6760517684048b39 (cherry picked from commit 2a54b8b7bf15912240dc2f12d2cd71dc620001e1)
This commit is contained in:
@ -1878,6 +1878,7 @@
|
||||
MkldnnCPU: empty_mkldnn
|
||||
SparseCPU, SparseCUDA: empty_sparse
|
||||
SparseCsrCPU, SparseCsrCUDA: empty_sparse_csr
|
||||
QuantizedCPU, QuantizedCUDA: empty_unknown_quantized
|
||||
|
||||
# We do not make new_empty a composite that calls into new_empty_strided, as the strided version
|
||||
# is significantly more difficult to implement by different backends
|
||||
@ -1949,6 +1950,7 @@
|
||||
CPU: empty_strided_cpu
|
||||
CUDA: empty_strided_cuda
|
||||
Meta: empty_strided_meta
|
||||
QuantizedCPU, QuantizedCUDA: empty_strided_unknown_quantized
|
||||
|
||||
- func: erf(Tensor self) -> Tensor
|
||||
device_check: NoCheck # TensorIterator
|
||||
|
@ -66,6 +66,40 @@ Tensor empty_per_channel_affine_quantized(
|
||||
quantizer);
|
||||
}
|
||||
|
||||
Tensor empty_unknown_quantized(
|
||||
IntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
// See [Note: hacky wrapper removal for TensorOptions]
|
||||
TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
|
||||
|
||||
TORCH_CHECK(
|
||||
!(options_.has_memory_format() && optional_memory_format.has_value()),
|
||||
"Cannot set memory_format both in TensorOptions and explicit argument; please delete "
|
||||
"the redundant setter.");
|
||||
auto options = options_.merge_memory_format(optional_memory_format);
|
||||
TORCH_CHECK(
|
||||
options.has_dtype(),
|
||||
"Must provide data type for Tensor creation functions.");
|
||||
QuantizerPtr quantizer = make_unknown_quantizer(typeMetaToScalarType(options.dtype()));
|
||||
return new_qtensor(size, options, quantizer);
|
||||
}
|
||||
|
||||
Tensor empty_strided_unknown_quantized(
|
||||
IntArrayRef size,
|
||||
IntArrayRef strided,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
|
||||
TORCH_CHECK(false, "empty_strided not supported on quantized tensors yet see https://github.com/pytorch/pytorch/issues/74540")
|
||||
|
||||
}
|
||||
|
||||
// Provide better error message if dtype is wrong
|
||||
Tensor empty_affine_quantized_other_backends_stub(
|
||||
IntArrayRef,
|
||||
|
@ -417,4 +417,23 @@ Tensor from_blob_quantized_per_channel_affine(
|
||||
return qtensor;
|
||||
}
|
||||
|
||||
Tensor UnknownQuantizer::quantize(const Tensor& tensor) {
|
||||
TORCH_INTERNAL_ASSERT(false, "cannot call quantize on UnknownQuantizer");
|
||||
}
|
||||
Tensor UnknownQuantizer::dequantize(const Tensor& qtensor) {
|
||||
TORCH_INTERNAL_ASSERT(false, "cannot call dequantize on UnknownQuantizer");
|
||||
}
|
||||
Tensor& UnknownQuantizer::dequantize_out(Tensor& rtensor, const Tensor& qtensor) {
|
||||
TORCH_INTERNAL_ASSERT(false, "cannot call dequantize_out on UnknownQuantizer");
|
||||
}
|
||||
QScheme UnknownQuantizer::qscheme() const {
|
||||
TORCH_INTERNAL_ASSERT(false, "cannot call qscheme on UnknownQuantizer");
|
||||
}
|
||||
bool UnknownQuantizer::equalTo(QuantizerPtr other) {
|
||||
TORCH_INTERNAL_ASSERT(false, "cannot call equalTo on UnknownQuantizer");
|
||||
}
|
||||
QuantizerPtr make_unknown_quantizer(ScalarType scalar_type) {
|
||||
return c10::make_intrusive<UnknownQuantizer>(scalar_type);
|
||||
}
|
||||
|
||||
} // namespace at
|
||||
|
@ -18,6 +18,23 @@
|
||||
|
||||
namespace at {
|
||||
|
||||
/**
|
||||
* UnknownQuantizer is a placeholder quantizer for functions that implement
|
||||
* quantization in a two step process. First a tensor is allocated but with
|
||||
* unknown quantizer, and then the quantization kernel decides what the final
|
||||
* quantizer will be.
|
||||
*/
|
||||
struct TORCH_API UnknownQuantizer : public Quantizer {
|
||||
explicit UnknownQuantizer(ScalarType scalar_type)
|
||||
: Quantizer(scalar_type) {}
|
||||
|
||||
Tensor quantize(const Tensor& tensor) override;
|
||||
Tensor dequantize(const Tensor& qtensor) override;
|
||||
Tensor& dequantize_out(Tensor& rtensor, const Tensor& qtensor) override;
|
||||
QScheme qscheme() const override;
|
||||
bool equalTo(QuantizerPtr other) override;
|
||||
};
|
||||
|
||||
/**
|
||||
* UniformQuantizer is the parent class for all uniform quantizers.
|
||||
* These quantization scheme will map float value uniformly to
|
||||
@ -222,6 +239,8 @@ TORCH_API QuantizerPtr make_per_channel_affine_quantizer(
|
||||
int64_t axis,
|
||||
ScalarType scalar_type);
|
||||
|
||||
TORCH_API QuantizerPtr make_unknown_quantizer(ScalarType scalar_type);
|
||||
|
||||
// Create a Quantized Tensor given arguments for normal Tensor and a quantizer
|
||||
TORCH_API Tensor new_qtensor(
|
||||
IntArrayRef sizes,
|
||||
|
@ -935,5 +935,20 @@ CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd
|
||||
r"Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed"):
|
||||
dispatcher.register(["CompositeExplicitAutograd", "CompositeImplicitAutograd"])
|
||||
|
||||
def test_quantized_structured_not_implemented(self):
|
||||
x = torch.zeros([1, 1, 1])
|
||||
y = torch.zeros([1, 1, 1])
|
||||
scale, zero_point = 1.0, 0
|
||||
dtype = torch.qint8
|
||||
qx = torch.quantize_per_tensor(x, scale, zero_point, dtype)
|
||||
qy = torch.quantize_per_tensor(y, scale, zero_point, dtype)
|
||||
# If bmm gets quantized support you need to update this to something
|
||||
# else that is not implemented
|
||||
self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
"Could not run 'aten::bmm.out' with arguments from the 'QuantizedCPU' backend.",
|
||||
lambda: torch.bmm(qx, qy)
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -62,30 +62,15 @@ def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
|
||||
dispatch = str(backend_index.dispatch_key).lower()
|
||||
empty_impl = f"at::detail::empty_{dispatch}"
|
||||
empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
|
||||
runtime_empty_supported_check = ""
|
||||
elif backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
|
||||
elif backend_index.dispatch_key in (
|
||||
DispatchKey.CompositeExplicitAutograd, DispatchKey.QuantizedCPU, DispatchKey.QuantizedCUDA):
|
||||
empty_impl = "at::empty"
|
||||
empty_strided_impl = "at::empty_strided"
|
||||
runtime_empty_supported_check = """\
|
||||
if (!c10::detail::backend_supports_empty_operator(options)) {{
|
||||
// The main purpose of this CompositeExplicitAutograd kernel is to provide
|
||||
// a "free" implementation of out-of-place operators.
|
||||
// If a backend hasn't implemented an out-of-place op but has implemented
|
||||
// the out= variant, then this kernel will call their out= variant.
|
||||
// It does that by using at::empty() to create the tensor to pass to the out= variant though,
|
||||
// so this "default" kernel doesn't actually handle backends that don't support at::empty
|
||||
// (e.g. quantized backends).
|
||||
// Returning an undefined tensor here allows us to reach the out= kernel and give a better error.
|
||||
// Longer term, this could be better fixed by https://github.com/pytorch/pytorch/issues/52680
|
||||
return at::Tensor();
|
||||
}}
|
||||
"""
|
||||
else:
|
||||
return []
|
||||
|
||||
return [f"""
|
||||
Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
|
||||
{runtime_empty_supported_check}
|
||||
if (strides.empty()) {{
|
||||
return {empty_impl}(sizes, {empty_options});
|
||||
}} else {{
|
||||
|
Reference in New Issue
Block a user