Add basic torch.hash_tensor op (#154149)

Added a `torch.hash_tensor` reduction function with a `mode` argument that defaults to an xor reduction; a usage sketch follows the notes below.

- The hash is always uint64.
- Integers are cast to uint64 before the xor_sum reduction.
- Floats are upcast to double and then bitcast to uint64 before the xor_sum reduction.
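
A minimal usage sketch, assuming the Python surface matches the schema added in native_functions.yaml below (the tensor `x` and the asserts are illustrative, not taken from the commit):

```python
import torch

x = torch.randn(2, 4)

# Full reduction over all elements: a scalar hash of dtype torch.uint64.
h = torch.hash_tensor(x)
assert h.dtype == torch.uint64

# Per-row hashes: reduce only over dim 1.
h_rows = torch.hash_tensor(x, 1)
assert h_rows.shape == (2,)

# mode=0 (the default) selects the xor_sum reduction; it is the only mode for now.
h_xor = torch.hash_tensor(x, mode=0)
```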

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154149
Approved by: https://github.com/albanD
Mikayla Gawarecki
2025-07-23 08:23:11 -07:00
committed by PyTorch MergeBot
parent 86df3ff1f1
commit 7f649ed4f8
15 changed files with 247 additions and 0 deletions

View File

@@ -71,6 +71,8 @@
#include <ATen/ops/exp.h>
#include <ATen/ops/gather.h>
#include <ATen/ops/gradient_native.h>
#include <ATen/ops/hash_tensor.h>
#include <ATen/ops/hash_tensor_native.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/isnan_native.h>
#include <ATen/ops/linalg_vector_norm.h>
@@ -398,6 +400,19 @@ TORCH_META_FUNC(amin)
resize_reduction(*this, self, dim, keepdim, out_dtype);
}
TORCH_META_FUNC(hash_tensor)
(const Tensor& self, IntArrayRef dim, bool keepdim, int64_t mode) {
auto maybe_result = maybe_get_output();
if (maybe_result.defined()){
TORCH_CHECK(maybe_result.scalar_type() == at::kUInt64, "Expected result to be of dtype uint64, but got ", maybe_result.scalar_type());
}
if (self.sym_numel() == 0) {
native::zero_numel_check_dims(self, dim, "hash_tensor");
}
resize_reduction(*this, self, dim, keepdim, at::kUInt64);
}
} // namespace at::meta
namespace at::native {
@@ -441,6 +456,7 @@ DEFINE_DISPATCH(argmin_stub);
DEFINE_DISPATCH(cumsum_stub);
DEFINE_DISPATCH(cumprod_stub);
DEFINE_DISPATCH(logcumsumexp_stub);
DEFINE_DISPATCH(xor_sum_stub);
Tensor _logcumsumexp_cpu(const Tensor& self, int64_t dim) {
Tensor result = at::empty_like(self, MemoryFormat::Contiguous);
@@ -2233,6 +2249,24 @@ Tensor dist(const Tensor &self, const Tensor& other, const Scalar& p){
return at::norm(self - other, p);
}
enum class HashMode { XOR_SUM = 0 };
TORCH_IMPL_FUNC(hash_tensor_out) (const Tensor& self, IntArrayRef dim, bool keepdim, int64_t mode, const Tensor& result) {
auto iter = meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
switch (static_cast<HashMode>(mode)) {
case HashMode::XOR_SUM:
if (iter.numel() == 0) {
result.fill_(0);
} else {
xor_sum_stub(iter.device_type(), iter);
}
return;
default:
TORCH_CHECK(false, "Unknown hash_tensor mode: ", mode);
}
}
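
Taken together, the meta function above pins the output dtype to uint64 and the impl rejects any mode other than `XOR_SUM`. A small sketch of the resulting user-facing behavior (a hypothetical snippet; error text paraphrased from the TORCH_CHECK messages):

```python
import torch

x = torch.randn(3)

# The out= tensor must be uint64; anything else trips the check in the meta function.
bad_out = torch.empty((), dtype=torch.int64)
try:
    torch.hash_tensor(x, out=bad_out)
except RuntimeError as e:
    print(e)  # "Expected result to be of dtype uint64, but got ..."

# Any mode other than 0 (XOR_SUM) falls through to the "Unknown hash_tensor mode" check.
try:
    torch.hash_tensor(x, mode=123)
except RuntimeError as e:
    print(e)  # "Unknown hash_tensor mode: 123"
```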
bool cpu_equal(const Tensor& self, const Tensor& other) {
if (!at::namedinference::are_names_equal(
self.unsafeGetTensorImpl(), other.unsafeGetTensorImpl())) {

View File

@@ -27,6 +27,7 @@ DECLARE_DISPATCH(reduce_fn, min_values_stub)
DECLARE_DISPATCH(reduce_fn, max_values_stub)
DECLARE_DISPATCH(reduce_fn, argmax_stub)
DECLARE_DISPATCH(reduce_fn, argmin_stub)
DECLARE_DISPATCH(reduce_fn, xor_sum_stub)
using reduce_std_var_function =
void (*)(TensorIterator&, double correction, bool take_sqrt);

View File

@@ -425,6 +425,49 @@ static void argmin_kernel_impl(TensorIterator &iter) {
});
}
template <typename scalar_t, typename acc_t = uint64_t, typename out_t = acc_t>
struct XorSumOps {
inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
if (std::is_same<scalar_t, bool>::value) {
return acc ^ (data ? 1 : 0);
} else if (
std::is_same<scalar_t, float>::value ||
std::is_same<scalar_t, double>::value ||
std::is_same<scalar_t, at::BFloat16>::value ||
std::is_same<scalar_t, at::Half>::value) {
union {
double d;
uint64_t u;
} converter;
converter.d = static_cast<double>(data);
return acc ^ converter.u;
} else {
return acc ^ static_cast<uint64_t>(data);
}
}
inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
return a ^ b;
}
inline C10_DEVICE out_t project(acc_t a) const {
return a;
}
static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
return acc;
}
};
static void xor_sum_kernel_impl(TensorIterator& iter) {
// Use iter.dtype(1) to dispatch based on the type of the input tensor
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool, iter.dtype(1), "xor_sum_cpu", [&] {
binary_kernel_reduce(
iter, XorSumOps<scalar_t>(), static_cast<uint64_t>(0));
});
}
} // anonymous namespace
REGISTER_DISPATCH(std_var_stub, &std_var_kernel_impl)
@@ -439,6 +482,7 @@ REGISTER_DISPATCH(min_values_stub, &min_values_kernel_impl)
REGISTER_DISPATCH(max_values_stub, &max_values_kernel_impl)
REGISTER_DISPATCH(argmax_stub, &argmax_kernel_impl)
REGISTER_DISPATCH(argmin_stub, &argmin_kernel_impl)
REGISTER_DISPATCH(xor_sum_stub, &xor_sum_kernel_impl)
REGISTER_DISPATCH(cumprod_stub, &cumprod_cpu_kernel)
REGISTER_DISPATCH(cumsum_stub, &cumsum_cpu_kernel)
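
As a rough pure-Python rendering of the per-dtype handling in `XorSumOps::reduce` above (the helper names are hypothetical and not part of the commit):

```python
import struct

def float_bits_as_uint64(x: float) -> int:
    # Reinterpret the IEEE-754 double bit pattern as an unsigned 64-bit integer,
    # mirroring the union { double d; uint64_t u; } bitcast in the CPU kernel.
    (bits,) = struct.unpack("<Q", struct.pack("<d", x))
    return bits

def xor_sum(values) -> int:
    acc = 0
    for v in values:
        if isinstance(v, bool):
            acc ^= 1 if v else 0             # bools contribute 0 or 1
        elif isinstance(v, float):
            acc ^= float_bits_as_uint64(v)   # floats contribute their double bit pattern
        else:
            acc ^= v & 0xFFFFFFFFFFFFFFFF    # integers are widened to uint64
    return acc
```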

View File

@@ -154,6 +154,51 @@ struct prod_functor<c10::complex<at::Half>> {
#endif
};
template <typename scalar_t, typename enable = void>
struct xor_sum_functor {
void operator()(TensorIterator& iter) {
gpu_reduce_kernel<scalar_t, uint64_t>(
iter,
func_wrapper<uint64_t>(
[] GPU_LAMBDA(uint64_t a, uint64_t b) -> uint64_t {
return a ^ b;
}));
}
};
template <typename scalar_t>
struct xor_sum_functor<scalar_t, std::enable_if_t<!std::is_integral_v<scalar_t>>> {
void operator()(TensorIterator& iter) {
gpu_reduce_kernel<scalar_t, double>(
iter,
// implicitly upcast scalar_t to double
func_wrapper<double>([] GPU_LAMBDA(double a, double b) -> double {
union {
double d;
uint64_t u;
} a_converter, b_converter, result_converter;
a_converter.d = a;
b_converter.d = b;
result_converter.u = a_converter.u ^ b_converter.u;
// return a double, otherwise uint64_t will be cast to double
// when accumulating and the result will be wrong
return result_converter.d;
}));
}
};
template <typename scalar_t>
struct xor_sum_functor<scalar_t, std::enable_if_t<std::is_same_v<scalar_t, bool>>> {
void operator()(TensorIterator& iter) {
gpu_reduce_kernel<bool, uint64_t>(
iter, func_wrapper<uint64_t>([] GPU_LAMBDA(bool a, bool b) -> uint64_t {
// Bitcast to uint64_t after the XOR operation (using != for booleans)
return static_cast<uint64_t>(a != b);
}));
}
};
// The function `reduce_dispatch` below dispatches to the kernel based
// on the type of `iter`. It takes care of the common logic
// for handling Half-Precision floating types.
@@ -222,8 +267,17 @@ static void prod_kernel_cuda(TensorIterator& iter) {
reduce_dispatch<prod_functor>(iter, general_dispatcher);
}
static void xor_sum_kernel_cuda(TensorIterator& iter) {
// Use iter.dtype(1) to dispatch based on the type of the input tensor
AT_DISPATCH_ALL_TYPES_AND3(
kHalf, kBFloat16, kBool, iter.dtype(1), "xor_sum_cuda", [&]() {
xor_sum_functor<scalar_t>{}(iter);
});
}
REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda)
REGISTER_DISPATCH(nansum_stub, &nansum_kernel_cuda)
REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda)
REGISTER_DISPATCH(xor_sum_stub, &xor_sum_kernel_cuda)
} // namespace at::native
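
The comment in the floating-point specialization above is the subtle part: the lambda keeps the accumulator typed as `double` so its bits are only ever reinterpreted, never value-converted. A tiny sketch of why a uint64-to-double value conversion would corrupt the hash (the bit pattern is constructed purely for illustration):

```python
# A 64-bit pattern with both the top and bottom bits set needs 64 significant bits,
# which a double's 53-bit mantissa cannot hold, so a uint64 -> double value
# conversion would silently round it away.
pattern = (1 << 63) | 1
assert int(float(pattern)) != pattern
```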

View File

@@ -5869,6 +5869,15 @@
CPU, CUDA: nansum_out
MPS: nansum_out_mps
- func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor
variants: function, method
structured_delegate: hash_tensor.out
- func: hash_tensor.out(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: hash_tensor_out
- func: sum_to_size(Tensor self, SymInt[] size) -> Tensor
variants: method
device_check: NoCheck

View File

@@ -475,6 +475,7 @@ Reduction Ops
var
var_mean
count_nonzero
hash_tensor
Comparison Ops
~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -852,6 +852,8 @@ aten::hann_window.periodic
aten::hann_window.periodic_out
aten::hardshrink_backward
aten::hardshrink_backward.grad_input
aten::hash_tensor
aten::hash_tensor.out
aten::histc
aten::histc.out
aten::histogram.bin_ct

View File

@@ -8088,6 +8088,7 @@ FORWARD_SKIPS_AND_XFAILS = [
"std.unbiased",
"var",
"var.unbiased",
"hash_tensor",
},
name="not_implemented",
),

View File

@@ -1799,6 +1799,9 @@
self: zeros_like(grad)
result: auto_element_wise
- name: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor
output_differentiability: [False]
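
Because the output is marked non-differentiable, the hash never enters the autograd graph; a quick hypothetical check:

```python
import torch

x = torch.randn(3, requires_grad=True)
h = torch.hash_tensor(x)
assert not h.requires_grad  # output_differentiability: [False] detaches the result from autograd
```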
# DO NOT define a backward for to_dense
# See [Note: Sometimes view derivatives]
# - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor

View File

@@ -1958,6 +1958,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch.hamming_window",
"torch.hann_window",
"torch.hardshrink",
"torch.hash_tensor",
"torch.heaviside",
"torch.hinge_embedding_loss",
"torch.histc",

View File

@@ -5018,6 +5018,60 @@ Alias for :func:`torch.gt`.
""",
)
add_docstr(
torch.hash_tensor,
r"""
hash_tensor(input, *, mode=0) -> Tensor
Returns a hash of all elements in the :attr:`input` tensor.
Currently only ``mode=0`` (reduction via xor) is supported. The output always
has dtype ``torch.uint64``. The elements of ``input`` are upcast to their
64-bit float / integer equivalent and bitcast to ``torch.uint64`` before the
xor reduction.
Args:
{input}
Keyword Args:
mode (int): The hash mode to use. Default: ``0`` (xor reduction)
Example::
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 1.1918, -1.1813, 0.3373]])
>>> torch.hash_tensor(a)
tensor(13822780554648485888, dtype=torch.uint64)
.. function:: hash_tensor(input, dim, *, keepdim=False, mode=0) -> Tensor
:noindex:
Returns the hash of each row of the :attr:`input` tensor along the given
dimension :attr:`dim`, using the hash specified by :attr:`mode`. If :attr:`dim`
is a list of dimensions, reduce over all of them.
{keepdim_details}
Args:
{input}
{opt_dim_all_reduce}
{opt_keepdim}
Keyword Args:
mode (int): The hash mode to use. Default: ``0`` (xor reduction)
Example::
>>> a = torch.randn(2, 4)
>>> a
tensor([[ 0.1317, -0.5554, -1.4724, -1.1391],
[ 0.0778, -0.6070, 0.6375, 0.1798]])
>>> torch.hash_tensor(a, 1)
tensor([9233691267014066176, 9255993250844508160], dtype=torch.uint64)
""".format(**multi_dim_common),
)
add_docstr(
torch.histc,
r"""

View File

@@ -675,6 +675,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
torch.gt: lambda input, other, out=None: -1,
torch.greater: lambda input, other, out=None: -1,
torch.hardshrink: lambda input, lambd=0.5: -1,
torch.hash_tensor: lambda input, dim=None, keepdim=False, mode=0, out=None: -1,
torch.heaviside: lambda input, values, out=None: -1,
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,

View File

@@ -11595,6 +11595,26 @@ def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=Fal
split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret
return np.stack(split_ret).reshape(orig_shape)
def reference_hash_tensor(tensor, dim=(), keepdim=False, mode=0):
assert mode == 0, "Only mode=0 (xor_sum) is supported right now"
dtype = tensor.dtype
if dtype.kind == 'f':
tensor = tensor.astype(np.float64).view(np.uint64)
else:
tensor = tensor.astype(np.uint64)
if dim == ():
result = np.bitwise_xor.reduce(tensor.flatten(), keepdims=keepdim)
else:
if isinstance(dim, list):
dim = tuple(dim)
result = np.bitwise_xor.reduce(tensor, axis=dim, keepdims=keepdim)
return result
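
A hypothetical sanity check of the NumPy reference above against the op itself (not part of the test file; an integer input keeps the comparison exact):

```python
import torch

x = torch.arange(8, dtype=torch.int32).reshape(2, 4)
expected = reference_hash_tensor(x.numpy(), dim=1)  # xor over each row
actual = torch.hash_tensor(x, 1)
assert actual.tolist() == expected.tolist()
```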
def loss_reference_reduction_wrapper(fn):
def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs):
if size_average is not None or reduce is not None:
@@ -21362,6 +21382,26 @@ op_db: list[OpInfo] = [
"TestConsistency", "test_output_match", device_type="mps"),
),
),
ReductionOpInfo(
'hash_tensor',
result_dtype=torch.uint64,
supports_autograd=False,
dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
ref=reference_hash_tensor,
skips=(
# hash_tensor reduces all dimensions when dim=[] (as do sum, prod etc.)
DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
# aten::hash_tensor hit the vmap fallback which is currently disabled
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
# NYI
DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
# Sharding strategy NYI
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
)
),
OpInfo(
"nn.functional.ctc_loss",
dtypes=floating_types(),

View File

@@ -390,6 +390,7 @@ if torch.backends.mps.is_available():
"gcd": None,
"geqrf": None,
"nn.functional.grid_sample": None, # Unsupported Border padding mode
"hash_tensor": None,
"heaviside": None,
"igamma": None,
"igammac": None,

View File

@@ -107,6 +107,7 @@ extra_op_data = {
"flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]),
"flip": ExtraOpData(dim_args=[["dims..."]]),
"gather": ExtraOpData(dim_args=[["dim"]]),
"hash_tensor": ExtraOpData(dim_args=[["dim..."]]),
"imag": ExtraOpData(is_view=True),
"index_add": ExtraOpData(dim_args=[["dim"]]),
"index_copy": ExtraOpData(dim_args=[["dim"]]),