Rowwise Prune op (Add the test to OSS run_test), Make the op private. (#46131)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46131

Refer to the title.

Test Plan: `buck test caffe2/test:pruning`

Reviewed By: raghuramank100

Differential Revision: D24230472

fbshipit-source-id: 8f0a83446c23fdf30d0313b8c3f5ff1a463b50c7
This commit is contained in:
Radhakrishnan Venkataramani
2021-01-29 06:06:12 -08:00
committed by Facebook GitHub Bot
parent ebe26b81d2
commit 3397919dcf
6 changed files with 9 additions and 6 deletions

View File

@ -80,9 +80,9 @@ std::tuple<Tensor, Tensor> _rowwise_prune_helper(
// post pruning.
// 2. An 1D tensor that contains the mapping between original weight row and
// the corresponding row in the pruned weights tensor.
std::tuple<Tensor, Tensor> rowwise_prune(const Tensor& weights,
const Tensor& mask,
ScalarType compressed_indices_dtype) {
std::tuple<Tensor, Tensor> _rowwise_prune(const Tensor& weights,
const Tensor& mask,
ScalarType compressed_indices_dtype) {
TORCH_CHECK(weights.ndimension() == 2,
"'weights' should have 2 dimensions.");
TORCH_CHECK(

View File

@ -1483,7 +1483,7 @@
CPU: _embedding_bag_forward_only_cpu
CUDA: _embedding_bag_forward_only_cuda
- func: rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)
- func: _rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)
# row_stack is the alias of vstack
- func: row_stack(Tensor[] tensors) -> Tensor

View File

@ -58,6 +58,7 @@ allow_list = [
("prim::profile_optional", datetime.date(2021, 1, 31)),
("aten::fake_quantize_per_tensor_affine_backward", datetime.date(2021, 2, 20)),
("aten::fake_quantize_per_channel_affine_backward", datetime.date(2021, 2, 20)),
("aten::rowwise_prune", datetime.date(9999, 1, 1)),
]
def allow_listed(schema, allow_list):

View File

@ -65,6 +65,7 @@ TESTS = [
'test_vulkan',
'test_sparse',
'test_quantization',
'test_pruning_op',
'test_spectral_ops',
'test_serialization',
'test_shape_ops',
@ -263,6 +264,7 @@ SLOW_TESTS = [
'distributed/test_jit_c10d',
'distributed/test_c10d_spawn',
'test_quantization',
'test_pruning_op',
'test_determination',
'test_futures',
'distributed/pipeline/sync/skip/test_api',

View File

@ -29,7 +29,7 @@ class PruningOpTest(TestCase):
mask = self._generate_rowwise_mask(embedding_rows)
def get_pt_result(embedding_weights, mask, indices_type):
return torch.rowwise_prune(embedding_weights, mask, indices_type)
return torch._rowwise_prune(embedding_weights, mask, indices_type)
# Reference implementation.
def get_reference_result(embedding_weights, mask, indices_type):

View File

@ -777,7 +777,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
torch.round: lambda input, out=None: -1,
torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
torch.rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
torch.rsqrt: lambda input, out=None: -1,
torch.rsub: lambda input, other, alpha=1: -1,