Rowwise Prune op (Add the test to OSS run_test), Make the op private. (#46131)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46131

Refer to the title.

Test Plan: `buck test caffe2/test:pruning`

Reviewed By: raghuramank100

Differential Revision: D24230472

fbshipit-source-id: 8f0a83446c23fdf30d0313b8c3f5ff1a463b50c7
This commit is contained in:
Radhakrishnan Venkataramani
2021-01-29 06:06:12 -08:00
committed by Facebook GitHub Bot
parent ebe26b81d2
commit 3397919dcf
6 changed files with 9 additions and 6 deletions

View File

@ -29,7 +29,7 @@ class PruningOpTest(TestCase):
mask = self._generate_rowwise_mask(embedding_rows)
def get_pt_result(embedding_weights, mask, indices_type):
return torch.rowwise_prune(embedding_weights, mask, indices_type)
return torch._rowwise_prune(embedding_weights, mask, indices_type)
# Reference implementation.
def get_reference_result(embedding_weights, mask, indices_type):