Add optional check_pinning argument to _validate_sparse_compressed_tensor/coo_args (#154759)

As in the title.

A prerequisite to https://github.com/pytorch/pytorch/pull/154638 .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154759
Approved by: https://github.com/amjames, https://github.com/ngimel
ghstack dependencies: #154610
This commit is contained in:
Pearu Peterson
2025-06-01 08:39:05 +03:00
committed by PyTorch MergeBot
parent 3f3c1f419f
commit ff4515fde5
3 changed files with 36 additions and 29 deletions

View File

@@ -7270,13 +7270,13 @@
dispatch:
CompositeImplicitAutograd: _sparse_coo_tensor_unsafe_symint
- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None) -> ()
- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None, bool? check_pinning=None) -> ()
- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout) -> ()
- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()
- func: _validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> ()
- func: _validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()
- func: _validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> ()
- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout, bool? check_pinning=None) -> ()
- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
- func: _validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
- func: _validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
- func: _validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
dispatch:

View File

@@ -125,7 +125,7 @@ bool solve_arange(const Tensor& input, int64_t& start, int64_t& end, int64_t& st
formats with support to batched and dense dimensions.
*/
static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) {
static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout, std::optional<bool> check_pinning_) {
// Layout must be Sparse Compressed, 2.4
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{});
@@ -134,6 +134,7 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
const std::string plain_indices_name = plainIndicesName(layout);
const std::string compressed_dim_name = compressedDimName(layout);
const std::string plain_dim_name = plainDimName(layout);
const bool check_pinning = check_pinning_.value_or(true);
// Layout Invariants
@@ -295,20 +296,22 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
") must match device of ", plain_indices_name, " (=",
plain_indices.device(),
")");
TORCH_CHECK(
if (check_pinning) {
TORCH_CHECK(
compressed_indices.is_pinned() == values.is_pinned(),
"memory pinning of ", compressed_indices_name, " (=",
compressed_indices.is_pinned(),
") must match memory pinning of values (=",
values.is_pinned(),
")");
TORCH_CHECK(
TORCH_CHECK(
compressed_indices.is_pinned() == plain_indices.is_pinned(),
"memory pinning of ", compressed_indices_name, " (=",
compressed_indices.is_pinned(),
") must match memory pinning of ", plain_indices_name, " (=",
plain_indices.is_pinned(),
")");
}
// Autograd Invariants
//
@@ -319,24 +322,24 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
TORCH_INTERNAL_ASSERT(!plain_indices.requires_grad());
}
void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) {
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout);
void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout, std::optional<bool> check_pinning) {
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout, check_pinning);
}
void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr);
void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, std::optional<bool> check_pinning) {
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr, check_pinning);
}
void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
_validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc);
void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size, std::optional<bool> check_pinning) {
_validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc, check_pinning);
}
void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr);
void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, std::optional<bool> check_pinning) {
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr, check_pinning);
}
void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
_validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc);
void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size, std::optional<bool> check_pinning) {
_validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc, check_pinning);
}
// Construction of CSR, CSC, BSR, and BSC tensors.
@@ -467,7 +470,7 @@ Tensor _sparse_compressed_tensor_unsafe_symint(
Layout layout_ = layout.value();
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
if (at::globalContext().checkSparseTensorInvariants()) {
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, C10_AS_INTARRAYREF_SLOW(size), layout_);
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, C10_AS_INTARRAYREF_SLOW(size), layout_, true);
}
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_compressed_tensor(options);
@@ -491,7 +494,7 @@ static Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed
Layout layout_ = layout.value_or(required_layout);
TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ",required_layout, " but got ", layout_);
if (at::globalContext().checkSparseTensorInvariants()) {
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_, true);
}
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_compressed_tensor(options);

View File

@@ -371,9 +371,11 @@ void _validate_sparse_coo_tensor_args(
const Tensor& indices,
const Tensor& values_,
ArrayRef<int64_t> size,
std::optional<bool> is_coalesced_) {
std::optional<bool> is_coalesced_,
std::optional<bool> check_pinning_) {
Tensor values = expand_values_if_needed(values_);
bool is_coalesced = is_coalesced_.value_or(false);
const bool check_pinning = check_pinning_.value_or(true);
// the following checks are redundant because they are also checked in
// SparseTensorImpl::set_indices_and_values_unsafe but we need to ensure them
@@ -397,13 +399,15 @@ void _validate_sparse_coo_tensor_args(
"), but got ",
size.size());
TORCH_CHECK(
indices.is_pinned() == values.is_pinned(),
"memory pinning of indices (=",
indices.is_pinned(),
") must match memory pinning of values (=",
values.is_pinned(),
")");
if (check_pinning) {
TORCH_CHECK(
indices.is_pinned() == values.is_pinned(),
"memory pinning of indices (=",
indices.is_pinned(),
") must match memory pinning of values (=",
values.is_pinned(),
")");
}
// Check to make sure all indices are within the boundaries of `size`
if (indices.numel() > 0) {