mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add optional check_pinning argument to _validate_sparse_compressed_tensor/coo_args (#154759)
As in the title. A prerequisite to https://github.com/pytorch/pytorch/pull/154638 . Pull Request resolved: https://github.com/pytorch/pytorch/pull/154759 Approved by: https://github.com/amjames, https://github.com/ngimel ghstack dependencies: #154610
This commit is contained in:
committed by
PyTorch MergeBot
parent
3f3c1f419f
commit
ff4515fde5
@ -7270,13 +7270,13 @@
|
||||
dispatch:
|
||||
CompositeImplicitAutograd: _sparse_coo_tensor_unsafe_symint
|
||||
|
||||
- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None) -> ()
|
||||
- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None, bool? check_pinning=None) -> ()
|
||||
|
||||
- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout) -> ()
|
||||
- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()
|
||||
- func: _validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> ()
|
||||
- func: _validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()
|
||||
- func: _validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> ()
|
||||
- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout, bool? check_pinning=None) -> ()
|
||||
- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
|
||||
- func: _validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
|
||||
- func: _validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
|
||||
- func: _validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
|
||||
|
||||
- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
dispatch:
|
||||
|
@ -125,7 +125,7 @@ bool solve_arange(const Tensor& input, int64_t& start, int64_t& end, int64_t& st
|
||||
formats with support to batched and dense dimensions.
|
||||
*/
|
||||
|
||||
static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) {
|
||||
static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout, std::optional<bool> check_pinning_) {
|
||||
// Layout must be Sparse Compressed, 2.4
|
||||
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{});
|
||||
|
||||
@ -134,6 +134,7 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
|
||||
const std::string plain_indices_name = plainIndicesName(layout);
|
||||
const std::string compressed_dim_name = compressedDimName(layout);
|
||||
const std::string plain_dim_name = plainDimName(layout);
|
||||
const bool check_pinning = check_pinning_.value_or(true);
|
||||
|
||||
// Layout Invariants
|
||||
|
||||
@ -295,20 +296,22 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
|
||||
") must match device of ", plain_indices_name, " (=",
|
||||
plain_indices.device(),
|
||||
")");
|
||||
TORCH_CHECK(
|
||||
if (check_pinning) {
|
||||
TORCH_CHECK(
|
||||
compressed_indices.is_pinned() == values.is_pinned(),
|
||||
"memory pinning of ", compressed_indices_name, " (=",
|
||||
compressed_indices.is_pinned(),
|
||||
") must match memory pinning of values (=",
|
||||
values.is_pinned(),
|
||||
")");
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK(
|
||||
compressed_indices.is_pinned() == plain_indices.is_pinned(),
|
||||
"memory pinning of ", compressed_indices_name, " (=",
|
||||
compressed_indices.is_pinned(),
|
||||
") must match memory pinning of ", plain_indices_name, " (=",
|
||||
plain_indices.is_pinned(),
|
||||
")");
|
||||
}
|
||||
|
||||
// Autograd Invariants
|
||||
//
|
||||
@ -319,24 +322,24 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
|
||||
TORCH_INTERNAL_ASSERT(!plain_indices.requires_grad());
|
||||
}
|
||||
|
||||
void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) {
|
||||
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout);
|
||||
void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout, std::optional<bool> check_pinning) {
|
||||
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout, check_pinning);
|
||||
}
|
||||
|
||||
void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
|
||||
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr);
|
||||
void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, std::optional<bool> check_pinning) {
|
||||
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr, check_pinning);
|
||||
}
|
||||
|
||||
void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
|
||||
_validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc);
|
||||
void _validate_sparse_csc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size, std::optional<bool> check_pinning) {
|
||||
_validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseCsc, check_pinning);
|
||||
}
|
||||
|
||||
void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
|
||||
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr);
|
||||
void _validate_sparse_bsr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, std::optional<bool> check_pinning) {
|
||||
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseBsr, check_pinning);
|
||||
}
|
||||
|
||||
void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size) {
|
||||
_validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc);
|
||||
void _validate_sparse_bsc_tensor_args(const Tensor& ccol_indices, const Tensor& row_indices, const Tensor& values, IntArrayRef size, std::optional<bool> check_pinning) {
|
||||
_validate_sparse_compressed_tensor_args_worker(ccol_indices, row_indices, values, size, kSparseBsc, check_pinning);
|
||||
}
|
||||
|
||||
// Construction of CSR, CSC, BSR, and BSC tensors.
|
||||
@ -467,7 +470,7 @@ Tensor _sparse_compressed_tensor_unsafe_symint(
|
||||
Layout layout_ = layout.value();
|
||||
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
|
||||
if (at::globalContext().checkSparseTensorInvariants()) {
|
||||
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, C10_AS_INTARRAYREF_SLOW(size), layout_);
|
||||
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, C10_AS_INTARRAYREF_SLOW(size), layout_, true);
|
||||
}
|
||||
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
|
||||
SparseCsrTensor self = new_compressed_tensor(options);
|
||||
@ -491,7 +494,7 @@ static Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed
|
||||
Layout layout_ = layout.value_or(required_layout);
|
||||
TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ",required_layout, " but got ", layout_);
|
||||
if (at::globalContext().checkSparseTensorInvariants()) {
|
||||
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
|
||||
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_, true);
|
||||
}
|
||||
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
|
||||
SparseCsrTensor self = new_compressed_tensor(options);
|
||||
|
@ -371,9 +371,11 @@ void _validate_sparse_coo_tensor_args(
|
||||
const Tensor& indices,
|
||||
const Tensor& values_,
|
||||
ArrayRef<int64_t> size,
|
||||
std::optional<bool> is_coalesced_) {
|
||||
std::optional<bool> is_coalesced_,
|
||||
std::optional<bool> check_pinning_) {
|
||||
Tensor values = expand_values_if_needed(values_);
|
||||
bool is_coalesced = is_coalesced_.value_or(false);
|
||||
const bool check_pinning = check_pinning_.value_or(true);
|
||||
|
||||
// the following checks are redundant because they are also checked in
|
||||
// SparseTensorImpl::set_indices_and_values_unsafe but we need to ensure them
|
||||
@ -397,13 +399,15 @@ void _validate_sparse_coo_tensor_args(
|
||||
"), but got ",
|
||||
size.size());
|
||||
|
||||
TORCH_CHECK(
|
||||
indices.is_pinned() == values.is_pinned(),
|
||||
"memory pinning of indices (=",
|
||||
indices.is_pinned(),
|
||||
") must match memory pinning of values (=",
|
||||
values.is_pinned(),
|
||||
")");
|
||||
if (check_pinning) {
|
||||
TORCH_CHECK(
|
||||
indices.is_pinned() == values.is_pinned(),
|
||||
"memory pinning of indices (=",
|
||||
indices.is_pinned(),
|
||||
") must match memory pinning of values (=",
|
||||
values.is_pinned(),
|
||||
")");
|
||||
}
|
||||
|
||||
// Check to make sure all indices are within the boundaries of `size`
|
||||
if (indices.numel() > 0) {
|
||||
|
Reference in New Issue
Block a user