mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Add spdiags sparse matrix initialization (#78439)"
This reverts commit cfb2034b657e8527767f1f74854bc62b4d6d4927.
Reverted https://github.com/pytorch/pytorch/pull/78439 on behalf of https://github.com/suo because it broke Windows builds; see cfb2034b65
@@ -1,74 +0,0 @@
#include <ATen/Dispatch.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/sparse/SparseFactories.h>
#include <c10/core/Scalar.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/sparse_coo_tensor.h>
#endif

namespace at {
namespace native {
using namespace at::sparse;

namespace {
void _spdiags_kernel_cpu(
    TensorIterator& iter,
    const Tensor& diagonals,
    Tensor& values,
    Tensor& indices) {
  auto* row_index_write_ptr = indices[0].data_ptr<int64_t>();
  auto* col_index_write_ptr = indices[1].data_ptr<int64_t>();
  const int64_t diagonals_read_stride = diagonals.stride(1);
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      at::ScalarType::Bool,
      at::ScalarType::ComplexHalf,
      diagonals.scalar_type(),
      "spdiags_cpu",
      [&] {
        auto* values_write_ptr = values.data_ptr<scalar_t>();
        cpu_kernel(
            iter,
            [&](int64_t diag_index,
                int64_t diag_offset,
                int64_t out_offset,
                int64_t n_out) -> int64_t {
              if (n_out > 0) {
                auto* rows_start = row_index_write_ptr + out_offset;
                auto* cols_start = col_index_write_ptr + out_offset;
                auto* vals_start = values_write_ptr + out_offset;
                const int64_t first_col = std::max<int64_t>(diag_offset, 0);
                const int64_t first_row = first_col - diag_offset;
                auto* data_read = diagonals[diag_index].data_ptr<scalar_t>() +
                    first_col * diagonals_read_stride;
                for (int64_t i = 0; i < n_out; ++i) {
                  rows_start[i] = first_row + i;
                  cols_start[i] = first_col + i;
                  vals_start[i] = data_read[i * diagonals_read_stride];
                }
              }
              // dummy return
              return 0;
            });
      });
}

} // namespace

REGISTER_DISPATCH(spdiags_kernel_stub, &_spdiags_kernel_cpu)

} // namespace native
} // namespace at
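The lambda above turns each diagonal into a run of (row, column, value) triples. The following is a minimal Python sketch of the same indexing, written here only for illustration; the helper name coo_from_diagonals and the pure-Python loop are not part of the reverted change:

import torch

def coo_from_diagonals(diagonals, offsets, shape):
    # Illustrative mirror of _spdiags_kernel_cpu: walk each diagonal and emit
    # (row, col, value) triples, then assemble a COO tensor from them.
    rows, cols, vals = [], [], []
    n_rows, n_cols = shape
    n_diag_cols = diagonals.size(1)
    for diag_index, off in enumerate(offsets.tolist()):
        # Number of stored elements on this diagonal (nnz_per_diag in the composite op)
        n_out = min(off + n_rows, n_diag_cols) if off <= 0 else min(n_cols, n_diag_cols) - off
        if n_out <= 0:
            continue
        first_col = max(off, 0)      # column where the diagonal starts in the output
        first_row = first_col - off  # corresponding output row
        for i in range(n_out):
            rows.append(first_row + i)
            cols.append(first_col + i)
            vals.append(diagonals[diag_index, first_col + i])
    indices = torch.tensor([rows, cols], dtype=torch.long)
    values = torch.stack(vals) if vals else torch.empty(0, dtype=diagonals.dtype)
    return torch.sparse_coo_tensor(indices, values, shape)

Running this on the docstring example further down in this diff (torch.arange(9).reshape(3, 3) with offsets [0, -1, -2]) reproduces the six stored elements shown there.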
@@ -5281,11 +5281,6 @@
    SparseCPU: log_softmax_backward_sparse_cpu
    SparseCUDA: log_softmax_backward_sparse_cuda

- func: _spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor
  python_module: sparse
  dispatch:
    CPU: spdiags

- func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
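For orientation, the schema removed above is what exposed the operator to Python as torch.sparse.spdiags (via python_module: sparse and the docstring wrapper shown near the end of this diff). A hedged usage sketch, which only runs on a build that still contains #78439:

import torch

diags = torch.tensor([[1, 2, 3], [4, 5, 6]])
offsets = torch.tensor([0, -1])  # main diagonal and the first subdiagonal
s = torch.sparse.spdiags(diags, offsets, (3, 3))  # dispatches to the CPU kernel above
print(s.to_dense())
# tensor([[1, 0, 0],
#         [4, 2, 0],
#         [0, 5, 3]])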
@@ -1,95 +0,0 @@
#include <ATen/Dispatch.h>
#include <ATen/native/sparse/SparseFactories.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_unique.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sparse_coo_tensor.h>
#include <ATen/ops/where.h>
#endif

namespace at {
namespace native {

DEFINE_DISPATCH(spdiags_kernel_stub);

Tensor spdiags(
    const Tensor& diagonals,
    const Tensor& offsets,
    IntArrayRef shape,
    c10::optional<Layout> layout) {
  auto diagonals_2d = diagonals.dim() == 1 ? diagonals.unsqueeze(0) : diagonals;
  TORCH_CHECK(diagonals_2d.dim() == 2, "Diagonals must be vector or matrix");
  TORCH_CHECK(shape.size() == 2, "Output shape must be 2d");
  auto offsets_1d = offsets.dim() == 0 ? offsets.unsqueeze(0) : offsets;
  TORCH_CHECK(offsets_1d.dim() == 1, "Offsets must be scalar or vector");
  TORCH_CHECK(
      diagonals_2d.size(0) == offsets_1d.size(0),
      "Number of diagonals (",
      diagonals_2d.size(0),
      ") does not match the number of offsets (",
      offsets_1d.size(0),
      ")");
  if (layout) {
    TORCH_CHECK(
        (*layout == Layout::Sparse) || (*layout == Layout::SparseCsc) ||
            (*layout == Layout::SparseCsr),
        "Only output layouts (Sparse, SparseCsc, SparseCsr) are supported, got ",
        *layout);
  }
  TORCH_CHECK(
      offsets_1d.scalar_type() == at::kLong,
      "Offset Tensor must have dtype Long but got ",
      offsets_1d.scalar_type());

  TORCH_CHECK(
      offsets_1d.numel() == std::get<0>(at::_unique(offsets_1d)).numel(),
      "Offset tensor contains duplicate values");

  auto nnz_per_diag = at::where(
      offsets_1d.le(0),
      offsets_1d.add(shape[0]).clamp_max_(diagonals_2d.size(1)),
      offsets_1d.add(-std::min<int64_t>(shape[1], diagonals_2d.size(1))).neg());

  auto nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1);
  const auto nnz = diagonals_2d.size(0) > 0
      ? nnz_per_diag_cumsum.select(-1, -1).item<int64_t>()
      : int64_t{0};
  // Offsets into nnz for each diagonal
  auto result_mem_offsets = nnz_per_diag_cumsum.sub(nnz_per_diag);
  // coo tensor guts
  auto indices = at::empty({2, nnz}, offsets_1d.options());
  auto values = at::empty({nnz}, diagonals_2d.options());
  // We add this indexer to lookup the row of diagonals we are reading from at
  // each iteration
  const auto n_diag = offsets_1d.size(0);
  Tensor diag_index = at::arange(n_diag, offsets_1d.options());
  // cpu_kernel requires an output
  auto dummy = at::empty({1}, offsets_1d.options()).resize_({0});
  auto iter = TensorIteratorConfig()
                  .set_check_mem_overlap(false)
                  .add_output(dummy)
                  .add_input(diag_index)
                  .add_input(offsets_1d)
                  .add_input(result_mem_offsets)
                  .add_input(nnz_per_diag)
                  .build();
  spdiags_kernel_stub(iter.device_type(), iter, diagonals_2d, values, indices);
  auto result_coo = at::sparse_coo_tensor(indices, values, shape);
  if (layout) {
    if (*layout == Layout::SparseCsr) {
      return result_coo.to_sparse_csr();
    }
    if (*layout == Layout::SparseCsc) {
      return result_coo.to_sparse_csc();
    }
  }
  return result_coo;
}

} // namespace native
} // namespace at
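The at::where expression above sizes the output buffers before the kernel runs. A rough Python equivalent of that bookkeeping, using the same variable names (a sketch for illustration, not part of the reverted change):

import torch

def spdiags_bookkeeping(offsets_1d, diagonals_2d, shape):
    # Stored elements per diagonal: for offsets <= 0 the diagonal is cut off by the
    # bottom edge of the output or by the diagonals row length; for offsets > 0 it is
    # the smaller of the output width and the row length, minus the offset.
    nnz_per_diag = torch.where(
        offsets_1d.le(0),
        offsets_1d.add(shape[0]).clamp_max(diagonals_2d.size(1)),
        offsets_1d.add(-min(shape[1], diagonals_2d.size(1))).neg())
    nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1)
    nnz = int(nnz_per_diag_cumsum[-1]) if diagonals_2d.size(0) > 0 else 0
    # Write offset of each diagonal into the flat values/indices buffers
    result_mem_offsets = nnz_per_diag_cumsum - nnz_per_diag
    return nnz, nnz_per_diag, result_mem_offsets

For the docstring example below (offsets [0, -1, -2] on a 3x3 output) this gives nnz_per_diag = [3, 2, 1] and hence nnz = 6, matching the COO result shown in the documentation hunk.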
@@ -1,15 +0,0 @@
#pragma once
#include <ATen/TensorIterator.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

using spdiags_kernel_fn_t =
    void (*)(TensorIterator&, const Tensor&, Tensor&, Tensor&);

DECLARE_DISPATCH(spdiags_kernel_fn_t, spdiags_kernel_stub);
} // namespace native
} // namespace at
@@ -1155,7 +1155,6 @@ aten_native_source_codegen_list = [
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
    "aten/src/ATen/native/cpu/spherical_bessel_j0.cpp",
    "aten/src/ATen/native/cpu/SparseFactories.cpp",
    "aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
]

@@ -1358,7 +1357,6 @@ aten_native_source_non_codegen_list = [
    "aten/src/ATen/native/sparse/SparseTensorMath.cpp",
    "aten/src/ATen/native/sparse/SparseUnaryOps.cpp",
    "aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
    "aten/src/ATen/native/sparse/SparseFactories.cpp",
    "aten/src/ATen/native/transformers/attention.cpp",
    "aten/src/ATen/native/transformers/transformer.cpp",
    "aten/src/ATen/native/utils/Factory.cpp",
@@ -599,7 +599,6 @@ Torch functions specific to sparse Tensors
    smm
    sparse.softmax
    sparse.log_softmax
    sparse.spdiags

Other functions
+++++++++++++++
@@ -8,7 +8,7 @@ import random
import unittest
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
    do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS, gradcheck, coalescedonoff, \
    DeterministicGuard, first_sample
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number

@@ -26,9 +26,6 @@ from torch.testing._internal.common_dtype import (
    floating_and_complex_types_and, integral_types, floating_types_and,
)

if TEST_SCIPY:
    import scipy.sparse

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
@@ -3561,94 +3558,6 @@ class TestSparse(TestCase):
        test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3])
        test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3])

    @unittest.skipIf(not TEST_NUMPY, "NumPy is not available")
    @onlyCPU
    @dtypes(*all_types_and_complex_and(torch.bool))
    def test_sparse_spdiags(self, device, dtype):

        make_diags = functools.partial(make_tensor, dtype=dtype, device=device)
        make_offsets = functools.partial(torch.tensor, dtype=torch.long, device=device)

        if TEST_SCIPY:
            def reference(diags, offsets, shape):
                return scipy.sparse.spdiags(diags, offsets, *shape).toarray()

        else:
            def reference(diags, offsets, shape):
                result = torch.zeros(shape, dtype=dtype, device=device)
                for i, off in enumerate(offsets):
                    res_view = result.diagonal(off)
                    data = diags[i]
                    if off > 0:
                        data = data[off:]

                    m = min(res_view.shape[0], data.shape[0])
                    res_view[:m] = data[:m]
                return result

        def check_valid(diags, offsets, shape, layout=None):
            ref_out = reference(diags, offsets, shape)
            out = torch.sparse.spdiags(diags, offsets, shape, layout=layout)
            if layout is None:
                ex_layout = torch.sparse_coo
            else:
                ex_layout = layout
            out_dense = out.to_dense()
            self.assertTrue(out.layout == ex_layout, f"Output layout {out.layout} expected {ex_layout}")
            self.assertEqual(out_dense, ref_out, f"Result:\n{out_dense} does not match reference:\n{ref_out}")

        def check_invalid(args, error):
            with self.assertRaisesRegex(RuntimeError, error):
                torch.sparse.spdiags(*args)

        def valid_cases():
            # some normal cases
            yield (make_diags((1, 5)), make_offsets([0]), (5, 5))
            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4))
            # noncontiguous diags
            yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5))
            # noncontiguous offsets
            yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
            # noncontiguous diags + offsets
            yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
            # correct dimensionality, 2d, 2d, and shapes match, but the number of diagonals is zero
            yield (make_diags((0, 3)), make_offsets([]), (3, 3))
            # forward rotation of upper diagonals
            yield (make_diags((3, 8)), make_offsets([1, 2, 3]), (4, 4))
            # rotation exhausts input space to read from
            yield (make_diags((2, 3)), make_offsets([2, 1]), (3, 3))
            # simple cases repeated with special output format
            yield (make_diags((1, 5)), make_offsets([0]), (5, 5), torch.sparse_csc)
            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4), torch.sparse_csr)
            # vector diags
            yield (make_diags((3, )), make_offsets([1]), (4, 4))
            # scalar offset
            yield (make_diags((1, 3)), make_offsets(2), (4, 4))
            # offsets out of range
            yield (make_diags((1, 3)), make_offsets([3]), (3, 3))
            yield (make_diags((1, 3)), make_offsets([-3]), (3, 3))

        for case in valid_cases():
            check_valid(*case)

        def invalid_cases():
            yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d"
            yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector"
            yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix"
            yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)),\
                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
            yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)),\
                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
            yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided),\
                r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+"
            yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values"
            yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+"

        for case, error_regex in invalid_cases():
            check_invalid(case, error_regex)


class TestSparseOneOff(TestCase):
    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
@@ -262,97 +262,3 @@ Args:
        performed. This is useful for preventing data type
        overflows. Default: None
""")


spdiags = _add_docstr(
    _sparse._spdiags,
    r"""
sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor

Creates a sparse 2D tensor by placing the values from rows of
:attr:`diagonals` along specified diagonals of the output.

The :attr:`offsets` tensor controls which diagonals are set:

- If :attr:`offsets[i]` = 0, it is the main diagonal
- If :attr:`offsets[i]` < 0, it is below the main diagonal
- If :attr:`offsets[i]` > 0, it is above the main diagonal

The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
and an offset may not be repeated.

Args:
    diagonals (Tensor): Matrix storing diagonals row-wise
    offsets (Tensor): The diagonals to be set, stored as a vector
    shape (2-tuple of ints): The desired shape of the result
Keyword args:
    layout (:class:`torch.layout`, optional): The desired layout of the
        returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
        are supported. Default: ``torch.sparse_coo``

Examples:

Set the main and first two lower diagonals of a matrix::

    >>> diags = torch.arange(9).reshape(3, 3)
    >>> diags
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
    >>> s
    tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
                           [0, 1, 2, 0, 1, 0]]),
           values=tensor([0, 1, 2, 3, 4, 6]),
           size=(3, 3), nnz=6, layout=torch.sparse_coo)
    >>> s.to_dense()
    tensor([[0, 0, 0],
            [3, 1, 0],
            [6, 4, 2]])

Change the output layout::

    >>> diags = torch.arange(9).reshape(3, 3)
    >>> diags
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
    >>> s
    tensor(crow_indices=tensor([0, 1, 3, 6]),
           col_indices=tensor([0, 0, 1, 0, 1, 2]),
           values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
           layout=torch.sparse_csr)
    >>> s.to_dense()
    tensor([[0, 0, 0],
            [3, 1, 0],
            [6, 4, 2]])

Set partial diagonals of a large output::

    >>> diags = torch.tensor([[1, 2], [3, 4]])
    >>> offsets = torch.tensor([0, -1])
    >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
    tensor([[1, 0, 0, 0, 0],
            [3, 2, 0, 0, 0],
            [0, 4, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]])

.. note::

    When setting the values along a given diagonal, the index into the diagonal
    and the index into the row of :attr:`diagonals` are both used as the
    column index in the output. This has the effect that when setting a diagonal
    with a positive offset `k` the first value along that diagonal will be
    the value in position `k` of the row of :attr:`diagonals`.

    Specifying a positive offset::

        >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
        >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
        tensor([[1, 2, 3, 0, 0],
                [0, 2, 3, 0, 0],
                [0, 0, 3, 0, 0],
                [0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0]])
""")