mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Rowwise Prune PyTorch op (#42708)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42708 Add rowwise prune pytorch op. This operator introduces sparsity to the 'weights' matrix with the help of the importance indicator 'mask'. A row is considered important and not pruned if the mask value for that particular row is 1(True) and not important otherwise. Test Plan: buck test caffe2/torch/fb/sparsenn:test -- rowwise_prune buck test caffe2/test:pruning Reviewed By: supriyar Differential Revision: D22849432 fbshipit-source-id: 456f4f77c04158cdc3830b2e69de541c7272a46d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3a0e35c9f2
commit
8032dbc117
106
aten/src/ATen/native/RowwisePrune.cpp
Normal file
106
aten/src/ATen/native/RowwisePrune.cpp
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright 2004-present Facebook. All Rights Reserved.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename input_t>
|
||||
std::tuple<Tensor, Tensor> _rowwise_prune_helper(
|
||||
const Tensor& weights, const Tensor& mask,
|
||||
ScalarType compressed_indices_dtype) {
|
||||
int num_non_masked_rows = 0;
|
||||
auto mask_contig = mask.contiguous();
|
||||
auto mask_data = mask_contig.data_ptr<bool>();
|
||||
for (int i = 0; i < mask.numel(); ++i) {
|
||||
num_non_masked_rows += (((mask_data[i] == true)) ? 1 : 0);
|
||||
}
|
||||
int num_cols = weights.size(1);
|
||||
auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols},
|
||||
weights.options());
|
||||
auto compressed_indices_mapping = at::empty({mask.numel()},
|
||||
compressed_indices_dtype);
|
||||
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
weights.scalar_type(),
|
||||
"rowwise_prune_helper", [&]() {
|
||||
auto* pruned_2d_tensor_data = pruned_2d_tensor.data_ptr<scalar_t>();
|
||||
auto compressed_indices_mapping_data =
|
||||
compressed_indices_mapping.data_ptr<input_t>();
|
||||
auto weights_data = weights.data_ptr<scalar_t>();
|
||||
int last_row_kept = 0;
|
||||
for (int i = 0; i < mask.numel(); i++) {
|
||||
if (mask_data[i]) {
|
||||
memcpy(pruned_2d_tensor_data + last_row_kept * num_cols,
|
||||
weights_data + i * num_cols,
|
||||
num_cols * sizeof (scalar_t));
|
||||
compressed_indices_mapping_data[i] = last_row_kept;
|
||||
last_row_kept++;
|
||||
} else {
|
||||
compressed_indices_mapping_data[i] = -1;
|
||||
}
|
||||
}
|
||||
});
|
||||
return std::tuple<Tensor, Tensor>(pruned_2d_tensor,
|
||||
compressed_indices_mapping);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
// This operator introduces sparsity to the 'weights' matrix with the help
|
||||
// of the importance indicator 'mask'.
|
||||
//
|
||||
// A row is considered important and not pruned if the mask value for that
|
||||
// particular row is 1(True) and not important otherwise.
|
||||
//
|
||||
// This operator doesn't zero out the pruned rows in-place. Instead, it
|
||||
// returns a tuple that contains a pruned weights tensor as well as a map that
|
||||
// can be used to look up the original row in the pruned weights tensor.
|
||||
// We refer this map as 'compressed indices map' going forward.
|
||||
|
||||
// The 'compressed indices map' is an 1D tensor that contains one entry per
|
||||
// original row in 'weights'. The array index is the index for the original
|
||||
// non-pruned weight tensor and the value would be the re-mapped index in the
|
||||
// pruned weights tensor. If the value for a index is -1, it means the
|
||||
// corresponding row has been pruned from the original weight tensor.
|
||||
|
||||
// Arguments:
|
||||
// 'weights' - two dimensional matrix that needs to be prune.
|
||||
// 'mask' - 1D boolean tensor that represents whether a row is important or
|
||||
// not. A mask value of 1 means the row should be kept and 0 means the row
|
||||
// should be pruned.
|
||||
//
|
||||
// Returns:
|
||||
// A tuple containing two tensors,
|
||||
// 1. A pruned weight tensor that contains only the weights that are preserved
|
||||
// post pruning.
|
||||
// 2. An 1D tensor that contains the mapping between original weight row and
|
||||
// the corresponding row in the pruned weights tensor.
|
||||
std::tuple<Tensor, Tensor> rowwise_prune(const Tensor& weights,
|
||||
const Tensor& mask,
|
||||
ScalarType compressed_indices_dtype) {
|
||||
TORCH_CHECK(weights.ndimension() == 2,
|
||||
"'weights' should have 2 dimensions.");
|
||||
TORCH_CHECK(
|
||||
mask.numel() == weights.size(0),
|
||||
"Number of elements in 'mask' should be equivalent to the "
|
||||
"number of rows in 'weights'."
|
||||
)
|
||||
TORCH_CHECK(
|
||||
compressed_indices_dtype == ScalarType::Int ||
|
||||
compressed_indices_dtype == ScalarType::Long,
|
||||
"compressed_indices_dtype should be either int(int32) or long(int64).");
|
||||
|
||||
if (compressed_indices_dtype == at::ScalarType::Int) {
|
||||
return _rowwise_prune_helper<int32_t>(weights, mask,
|
||||
compressed_indices_dtype);
|
||||
}
|
||||
return _rowwise_prune_helper<int64_t>(weights, mask,
|
||||
compressed_indices_dtype);
|
||||
}
|
||||
|
||||
}} // namespace at::native
|
@ -1236,6 +1236,9 @@
|
||||
CPU: _embedding_bag_forward_only_cpu
|
||||
CUDA: _embedding_bag_forward_only_cuda
|
||||
|
||||
- func: rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)
|
||||
use_c10_dispatcher: full
|
||||
|
||||
- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)
|
||||
use_c10_dispatcher: full
|
||||
|
||||
|
78
test/test_pruning_op.py
Normal file
78
test/test_pruning_op.py
Normal file
@ -0,0 +1,78 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis import given
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
import torch.testing._internal.hypothesis_utils as hu
|
||||
hu.assert_deadline_disabled()
|
||||
|
||||
|
||||
class PruningOpTest(TestCase):
    """Compares torch.rowwise_prune against a pure-Python reference."""

    # Weight dtypes exercised by both the 32-bit and 64-bit index tests.
    _WEIGHTS_DTYPES = [torch.float64, torch.float32,
                       torch.float16, torch.int8,
                       torch.int16, torch.int32, torch.int64]

    # Generate rowwise mask vector based on indicator and threshold value.
    # indicator is a vector that contains one value per weight row and it
    # represents the importance of a row.
    # We mask a row if its indicator value is less than the threshold.
    def _generate_rowwise_mask(self, embedding_rows):
        """Return a random 1D bool tensor with `embedding_rows` entries.

        An entry is True (row kept) when its random indicator value is at
        least the random threshold.
        """
        indicator = torch.from_numpy(
            np.random.random_sample(embedding_rows).astype(np.float32))
        threshold = float(np.random.random_sample())
        # Element-wise comparison yields the bool mask directly.
        return indicator >= threshold

    def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
        """Run rowwise_prune on random weights and check both outputs.

        Args:
            embedding_rows: number of rows of the random weight matrix.
            embedding_dims: number of columns of the random weight matrix.
            indices_type: dtype requested for the compressed indices map.
            weights_dtype: dtype of the random weight matrix.
        """
        if weights_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            embedding_weights = torch.randint(0, 100, (embedding_rows, embedding_dims), dtype=weights_dtype)
        else:
            embedding_weights = torch.rand((embedding_rows, embedding_dims), dtype=weights_dtype)
        mask = self._generate_rowwise_mask(embedding_rows)

        def get_pt_result(embedding_weights, mask, indices_type):
            return torch.rowwise_prune(embedding_weights, mask, indices_type)

        # Reference implementation.
        def get_reference_result(embedding_weights, mask, indices_type):
            num_embeddings = mask.size()[0]
            compressed_idx_out = torch.zeros(num_embeddings, dtype=indices_type)
            # Boolean indexing keeps exactly the rows whose mask entry is True.
            pruned_weights_out = embedding_weights[mask]
            idx = 0
            for i in range(num_embeddings):
                if mask[i]:
                    compressed_idx_out[i] = idx
                    idx += 1
                else:
                    # -1 marks a row that was pruned away.
                    compressed_idx_out[i] = -1
            return (pruned_weights_out, compressed_idx_out)

        pt_pruned_weights, pt_compressed_indices_map = get_pt_result(
            embedding_weights, mask, indices_type)
        ref_pruned_weights, ref_compressed_indices_map = get_reference_result(
            embedding_weights, mask, indices_type)

        torch.testing.assert_allclose(pt_pruned_weights, ref_pruned_weights)
        self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map)
        self.assertEqual(pt_compressed_indices_map.dtype, indices_type)

    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from(_WEIGHTS_DTYPES)
    )
    def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        """rowwise_prune with an int32 compressed indices map."""
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)

    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from(_WEIGHTS_DTYPES)
    )
    def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        """rowwise_prune with an int64 compressed indices map."""
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype)
|
@ -672,6 +672,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.roll: lambda input, shifts, dims=None: -1,
|
||||
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
|
||||
torch.round: lambda input, out=None: -1,
|
||||
torch.rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
|
||||
torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
|
||||
torch.rsqrt: lambda input, out=None: -1,
|
||||
torch.rsub: lambda input, other, alpha=1: -1,
|
||||
|
Reference in New Issue
Block a user