Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add basic torch.hash_tensor op (#154149)
Added a `torch.hash_tensor` reduction function with a `mode` argument that defaults to reduction with xor.

- The hash is always uint64.
- Integers will be cast to uint64 before performing the xor_sum reduction.
- Floats will be upcast to double and then bitcast to uint64 before performing the xor_sum reduction.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154149
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 86df3ff1f1
Commit: 7f649ed4f8
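As an illustrative sketch (not part of this diff), the behavior described in the commit message can be checked against a NumPy reference similar to the `reference_hash_tensor` helper added below, assuming a PyTorch build that includes this change and a floating-point input:

```python
import numpy as np
import torch

x = torch.randn(2, 4)

# Per the commit message: floats are upcast to double, bitcast to uint64,
# and then xor-reduced; the result dtype is always uint64.
ref = np.bitwise_xor.reduce(x.numpy().astype(np.float64).view(np.uint64), axis=1)

# New op added by this PR; reduces over dim=1 here (mode=0, the default, is xor).
out = torch.hash_tensor(x, 1)

print(out)  # tensor([...], dtype=torch.uint64)
print(ref)  # matching uint64 values from the NumPy reference
```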
@@ -71,6 +71,8 @@
#include <ATen/ops/exp.h>
#include <ATen/ops/gather.h>
#include <ATen/ops/gradient_native.h>
#include <ATen/ops/hash_tensor.h>
#include <ATen/ops/hash_tensor_native.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/isnan_native.h>
#include <ATen/ops/linalg_vector_norm.h>
@@ -398,6 +400,19 @@ TORCH_META_FUNC(amin)
  resize_reduction(*this, self, dim, keepdim, out_dtype);
}

TORCH_META_FUNC(hash_tensor)
(const Tensor& self, IntArrayRef dim, bool keepdim, int64_t mode) {
  auto maybe_result = maybe_get_output();
  if (maybe_result.defined()){
    TORCH_CHECK(maybe_result.scalar_type() == at::kUInt64, "Expected result to be of dtype uint64, but got ", maybe_result.scalar_type());
  }
  if (self.sym_numel() == 0) {
    native::zero_numel_check_dims(self, dim, "hash_tensor");
  }
  resize_reduction(*this, self, dim, keepdim, at::kUInt64);
}


} // namespace at::meta

namespace at::native {
@@ -441,6 +456,7 @@ DEFINE_DISPATCH(argmin_stub);
DEFINE_DISPATCH(cumsum_stub);
DEFINE_DISPATCH(cumprod_stub);
DEFINE_DISPATCH(logcumsumexp_stub);
DEFINE_DISPATCH(xor_sum_stub);

Tensor _logcumsumexp_cpu(const Tensor& self, int64_t dim) {
  Tensor result = at::empty_like(self, MemoryFormat::Contiguous);
@@ -2233,6 +2249,24 @@ Tensor dist(const Tensor &self, const Tensor& other, const Scalar& p){
  return at::norm(self - other, p);
}

enum class HashMode { XOR_SUM = 0 };

TORCH_IMPL_FUNC(hash_tensor_out) (const Tensor& self, IntArrayRef dim, bool keepdim, int64_t mode, const Tensor& result) {

  auto iter = meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
  switch (static_cast<HashMode>(mode)) {
    case HashMode::XOR_SUM:
      if (iter.numel() == 0) {
        result.fill_(0);
      } else {
        xor_sum_stub(iter.device_type(), iter);
      }
      return;
    default:
      TORCH_CHECK(false, "Unknown hash_tensor mode: ", mode);
  }
}

bool cpu_equal(const Tensor& self, const Tensor& other) {
  if (!at::namedinference::are_names_equal(
          self.unsafeGetTensorImpl(), other.unsafeGetTensorImpl())) {
@@ -27,6 +27,7 @@ DECLARE_DISPATCH(reduce_fn, min_values_stub)
DECLARE_DISPATCH(reduce_fn, max_values_stub)
DECLARE_DISPATCH(reduce_fn, argmax_stub)
DECLARE_DISPATCH(reduce_fn, argmin_stub)
DECLARE_DISPATCH(reduce_fn, xor_sum_stub)

using reduce_std_var_function =
    void (*)(TensorIterator&, double correction, bool take_sqrt);
@@ -425,6 +425,49 @@ static void argmin_kernel_impl(TensorIterator &iter) {
  });
}

template <typename scalar_t, typename acc_t = uint64_t, typename out_t = acc_t>
struct XorSumOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    if (std::is_same<scalar_t, bool>::value) {
      return acc ^ (data ? 1 : 0);
    } else if (
        std::is_same<scalar_t, float>::value ||
        std::is_same<scalar_t, double>::value ||
        std::is_same<scalar_t, at::BFloat16>::value ||
        std::is_same<scalar_t, at::Half>::value) {
      union {
        double d;
        uint64_t u;
      } converter;
      converter.d = static_cast<double>(data);
      return acc ^ converter.u;
    } else {
      return acc ^ static_cast<uint64_t>(data);
    }
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a ^ b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }
};

static void xor_sum_kernel_impl(TensorIterator& iter) {
  // Use iter.dtype(1) to dispatch based on the type of the input tensor
  AT_DISPATCH_ALL_TYPES_AND3(
      kBFloat16, kHalf, kBool, iter.dtype(1), "xor_sum_cpu", [&] {
        binary_kernel_reduce(
            iter, XorSumOps<scalar_t>(), static_cast<uint64_t>(0));
      });
}

} // anonymous namespace

REGISTER_DISPATCH(std_var_stub, &std_var_kernel_impl)
@@ -439,6 +482,7 @@ REGISTER_DISPATCH(min_values_stub, &min_values_kernel_impl)
REGISTER_DISPATCH(max_values_stub, &max_values_kernel_impl)
REGISTER_DISPATCH(argmax_stub, &argmax_kernel_impl)
REGISTER_DISPATCH(argmin_stub, &argmin_kernel_impl)
REGISTER_DISPATCH(xor_sum_stub, &xor_sum_kernel_impl)

REGISTER_DISPATCH(cumprod_stub, &cumprod_cpu_kernel)
REGISTER_DISPATCH(cumsum_stub, &cumsum_cpu_kernel)
@@ -154,6 +154,51 @@ struct prod_functor<c10::complex<at::Half>> {
#endif
};

template <typename scalar_t, typename enable = void>
struct xor_sum_functor {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, uint64_t>(
        iter,
        func_wrapper<uint64_t>(
            [] GPU_LAMBDA(uint64_t a, uint64_t b) -> uint64_t {
              return a ^ b;
            }));
  }
};

template <typename scalar_t>
struct xor_sum_functor<scalar_t, std::enable_if_t<!std::is_integral_v<scalar_t>>> {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, double>(
        iter,
        // implicitly upcast scalar_t to double
        func_wrapper<double>([] GPU_LAMBDA(double a, double b) -> double {
          union {
            double d;
            uint64_t u;
          } a_converter, b_converter, result_converter;

          a_converter.d = a;
          b_converter.d = b;
          result_converter.u = a_converter.u ^ b_converter.u;
          // return a double, otherwise uint64_t will be cast to double
          // when accumulating and the result will be wrong
          return result_converter.d;
        }));
  }
};

template <typename scalar_t>
struct xor_sum_functor<scalar_t, std::enable_if_t<std::is_same_v<scalar_t, bool>>> {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<bool, uint64_t>(
        iter, func_wrapper<uint64_t>([] GPU_LAMBDA(bool a, bool b) -> uint64_t {
          // Bitcast to uint64_t after the XOR operation (using != for booleans)
          return static_cast<uint64_t>(a != b);
        }));
  }
};

// The function `reduce_dispatch` below dispatches to the kernel based
// on the type of `iter`. It takes care of the common logic
// for handling Half-Precision floating types.
@@ -222,8 +267,17 @@ static void prod_kernel_cuda(TensorIterator& iter) {
  reduce_dispatch<prod_functor>(iter, general_dispatcher);
}

static void xor_sum_kernel_cuda(TensorIterator& iter) {
  // Use iter.dtype(1) to dispatch based on the type of the input tensor
  AT_DISPATCH_ALL_TYPES_AND3(
      kHalf, kBFloat16, kBool, iter.dtype(1), "xor_sum_cuda", [&]() {
        xor_sum_functor<scalar_t>{}(iter);
      });
}

REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda)
REGISTER_DISPATCH(nansum_stub, &nansum_kernel_cuda)
REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda)
REGISTER_DISPATCH(xor_sum_stub, &xor_sum_kernel_cuda)

} // namespace at::native
@@ -5869,6 +5869,15 @@
    CPU, CUDA: nansum_out
    MPS: nansum_out_mps

- func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor
  variants: function, method
  structured_delegate: hash_tensor.out

- func: hash_tensor.out(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU, CUDA: hash_tensor_out

- func: sum_to_size(Tensor self, SymInt[] size) -> Tensor
  variants: method
  device_check: NoCheck
@@ -475,6 +475,7 @@ Reduction Ops
    var
    var_mean
    count_nonzero
    hash_tensor

Comparison Ops
~~~~~~~~~~~~~~~~~~~~~~
@@ -852,6 +852,8 @@ aten::hann_window.periodic
aten::hann_window.periodic_out
aten::hardshrink_backward
aten::hardshrink_backward.grad_input
aten::hash_tensor
aten::hash_tensor.out
aten::histc
aten::histc.out
aten::histogram.bin_ct
@@ -8088,6 +8088,7 @@ FORWARD_SKIPS_AND_XFAILS = [
            "std.unbiased",
            "var",
            "var.unbiased",
            "hash_tensor",
        },
        name="not_implemented",
    ),
@@ -1799,6 +1799,9 @@
  self: zeros_like(grad)
  result: auto_element_wise

- name: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor
  output_differentiability: [False]

# DO NOT define a backward for to_dense
# See [Note: Sometimes view derivatives]
# - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
@@ -1958,6 +1958,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
        "torch.hamming_window",
        "torch.hann_window",
        "torch.hardshrink",
        "torch.hash_tensor",
        "torch.heaviside",
        "torch.hinge_embedding_loss",
        "torch.histc",
@@ -5018,6 +5018,60 @@ Alias for :func:`torch.gt`.
""",
)

add_docstr(
    torch.hash_tensor,
    r"""
hash_tensor(input, *, mode=0) -> Tensor

Returns a hash of all elements in the :attr:`input` tensor.

Currently only mode=0 (reduction via xor) is supported. The output will always
be of type ``torch.uint64``. The elements of ``input`` are upcasted to their
64 bit float / integer equivalent and bitcasted to ``torch.uint64`` before
reduction via xor.

Args:
    {input}

Keyword Args:
    mode (int) : The hash to use. Default: 0 (xor_reduction)

Example::

    >>> a = torch.randn(1, 3)
    >>> a
    tensor([[ 1.1918, -1.1813, 0.3373]])
    >>> torch.hash_tensor(a)
    tensor(13822780554648485888, dtype=torch.uint64)

.. function:: hash_tensor(input, dim, *, keepdim=False, mode=0) -> Tensor
   :noindex:

Returns the hash of each row of the :attr:`input` tensor in the given
dimension :attr:`dim` given by mode. If :attr:`dim` is a list of dimensions,
reduce over all of them.

{keepdim_details}

Args:
    {input}
    {opt_dim_all_reduce}
    {opt_keepdim}

Keyword Args:
    mode (int) : The hash to use. Default: 0 (xor_reduction)

Example::

    >>> a = torch.randn(2, 4)
    >>> a
    tensor([[ 0.1317, -0.5554, -1.4724, -1.1391],
            [ 0.0778, -0.6070, 0.6375, 0.1798]])
    >>> torch.hash_tensor(a, 1)
    tensor([9233691267014066176, 9255993250844508160], dtype=torch.uint64)
""".format(**multi_dim_common),
)

add_docstr(
    torch.histc,
    r"""
@@ -675,6 +675,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
        torch.gt: lambda input, other, out=None: -1,
        torch.greater: lambda input, other, out=None: -1,
        torch.hardshrink: lambda input, lambd=0.5: -1,
        torch.hash_tensor: lambda input, dim=None, keepdim=False, mode=0, out=None: -1,
        torch.heaviside: lambda input, values, out=None: -1,
        torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
        torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
@@ -11595,6 +11595,26 @@ def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=Fal
    split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret
    return np.stack(split_ret).reshape(orig_shape)

def reference_hash_tensor(tensor, dim=(), keepdim=False, mode=0):
    assert mode == 0, "Only mode=0 (xor_sum) is supported right now"

    dtype = tensor.dtype
    if dtype.kind == 'f':
        tensor = tensor.astype(np.float64).view(np.uint64)
    else:
        tensor = tensor.astype(np.uint64)


    if dim == ():
        result = np.bitwise_xor.reduce(tensor.flatten(), keepdims=keepdim)
    else:
        if isinstance(dim, list):
            dim = tuple(dim)
        result = np.bitwise_xor.reduce(tensor, axis=dim, keepdims=keepdim)

    return result


def loss_reference_reduction_wrapper(fn):
    def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs):
        if size_average is not None or reduce is not None:
@@ -21362,6 +21382,26 @@ op_db: list[OpInfo] = [
            "TestConsistency", "test_output_match", device_type="mps"),
        ),
    ),
    ReductionOpInfo(
        'hash_tensor',
        result_dtype=torch.uint64,
        supports_autograd=False,
        dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
        dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
        ref=reference_hash_tensor,
        skips=(
            # hash_tensor reduces all dimensions when dim=[] (as do sum, prod etc.)
            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
            # aten::hash_tensor hit the vmap fallback which is currently disabled
            DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
            DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
            # NYI
            DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
            # Sharding strategy NYI
            DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
        )
    ),
    OpInfo(
        "nn.functional.ctc_loss",
        dtypes=floating_types(),
@@ -390,6 +390,7 @@ if torch.backends.mps.is_available():
        "gcd": None,
        "geqrf": None,
        "nn.functional.grid_sample": None,  # Unsupported Border padding mode
        "hash_tensor": None,
        "heaviside": None,
        "igamma": None,
        "igammac": None,
@@ -107,6 +107,7 @@ extra_op_data = {
    "flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]),
    "flip": ExtraOpData(dim_args=[["dims..."]]),
    "gather": ExtraOpData(dim_args=[["dim"]]),
    "hash_tensor": ExtraOpData(dim_args=[["dim..."]]),
    "imag": ExtraOpData(is_view=True),
    "index_add": ExtraOpData(dim_args=[["dim"]]),
    "index_copy": ExtraOpData(dim_args=[["dim"]]),