mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Rowwise Prune op (Add the test to OSS run_test), Make the op private. (#46131)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46131 Refer to the title. Test Plan: `buck test caffe2/test:pruning` Reviewed By: raghuramank100 Differential Revision: D24230472 fbshipit-source-id: 8f0a83446c23fdf30d0313b8c3f5ff1a463b50c7
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ebe26b81d2
commit
3397919dcf
@ -80,9 +80,9 @@ std::tuple<Tensor, Tensor> _rowwise_prune_helper(
|
||||
// post pruning.
|
||||
// 2. An 1D tensor that contains the mapping between original weight row and
|
||||
// the corresponding row in the pruned weights tensor.
|
||||
std::tuple<Tensor, Tensor> rowwise_prune(const Tensor& weights,
|
||||
const Tensor& mask,
|
||||
ScalarType compressed_indices_dtype) {
|
||||
std::tuple<Tensor, Tensor> _rowwise_prune(const Tensor& weights,
|
||||
const Tensor& mask,
|
||||
ScalarType compressed_indices_dtype) {
|
||||
TORCH_CHECK(weights.ndimension() == 2,
|
||||
"'weights' should have 2 dimensions.");
|
||||
TORCH_CHECK(
|
||||
|
@ -1483,7 +1483,7 @@
|
||||
CPU: _embedding_bag_forward_only_cpu
|
||||
CUDA: _embedding_bag_forward_only_cuda
|
||||
|
||||
- func: rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)
|
||||
- func: _rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)
|
||||
|
||||
# row_stack is the alias of vstack
|
||||
- func: row_stack(Tensor[] tensors) -> Tensor
|
||||
|
@ -58,6 +58,7 @@ allow_list = [
|
||||
("prim::profile_optional", datetime.date(2021, 1, 31)),
|
||||
("aten::fake_quantize_per_tensor_affine_backward", datetime.date(2021, 2, 20)),
|
||||
("aten::fake_quantize_per_channel_affine_backward", datetime.date(2021, 2, 20)),
|
||||
("aten::rowwise_prune", datetime.date(9999, 1, 1)),
|
||||
]
|
||||
|
||||
def allow_listed(schema, allow_list):
|
||||
|
@ -65,6 +65,7 @@ TESTS = [
|
||||
'test_vulkan',
|
||||
'test_sparse',
|
||||
'test_quantization',
|
||||
'test_pruning_op',
|
||||
'test_spectral_ops',
|
||||
'test_serialization',
|
||||
'test_shape_ops',
|
||||
@ -263,6 +264,7 @@ SLOW_TESTS = [
|
||||
'distributed/test_jit_c10d',
|
||||
'distributed/test_c10d_spawn',
|
||||
'test_quantization',
|
||||
'test_pruning_op',
|
||||
'test_determination',
|
||||
'test_futures',
|
||||
'distributed/pipeline/sync/skip/test_api',
|
||||
|
@ -29,7 +29,7 @@ class PruningOpTest(TestCase):
|
||||
mask = self._generate_rowwise_mask(embedding_rows)
|
||||
|
||||
def get_pt_result(embedding_weights, mask, indices_type):
|
||||
return torch.rowwise_prune(embedding_weights, mask, indices_type)
|
||||
return torch._rowwise_prune(embedding_weights, mask, indices_type)
|
||||
|
||||
# Reference implementation.
|
||||
def get_reference_result(embedding_weights, mask, indices_type):
|
||||
|
@ -777,7 +777,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
|
||||
torch.round: lambda input, out=None: -1,
|
||||
torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
|
||||
torch.rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
|
||||
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
|
||||
torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
|
||||
torch.rsqrt: lambda input, out=None: -1,
|
||||
torch.rsub: lambda input, other, alpha=1: -1,
|
||||
|
Reference in New Issue
Block a user