Add nonzero_static() op to pytorch to unblock export (#97417)

Summary: Add a new experimental Python op (`torch.nonzero_static`) for export. No CUDA implementation is included in this PR.

Example:

Say the input tensor is `x = torch.tensor([[1, 0], [3, 2]])`.

Calling regular `nonzero()` on `x` returns `tensor([[0, 0], [1, 0], [1, 1]])`.
Calling `nonzero_static(x, size=4)` returns `tensor([[0, 0], [1, 0], [1, 1], [fill_value, fill_value]])` (padded).
Calling `nonzero_static(x, size=2)` returns `tensor([[0, 0], [1, 0]])` (truncated).
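
The padding and truncation rules above can be sketched in plain Python. This is a minimal reference model, not the kernel added in this PR: it works on nested lists instead of tensors, and the names `shape_of` and `nonzero_static_ref` are hypothetical helpers introduced only for illustration.

```python
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular nested list; a scalar has shape ()."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def nonzero_static_ref(nested, size, fill_value=-1):
    """Hypothetical reference model: always returns exactly `size` index rows."""
    if size < 0:
        raise ValueError("size must be a non-negative integer")
    shape = shape_of(nested)
    rows = []
    # product() varies the last index fastest, i.e. row-major order.
    for idx in product(*(range(n) for n in shape)):
        value = nested
        for i in idx:
            value = value[i]
        if value != 0:
            rows.append(list(idx))
    rows = rows[:size]  # truncate when there are more non-zeros than `size`
    # Pad with rows of `fill_value` when there are fewer non-zeros than `size`.
    rows += [[fill_value] * len(shape) for _ in range(size - len(rows))]
    return rows


x = [[1, 0], [3, 2]]
padded = nonzero_static_ref(x, size=4)     # [[0, 0], [1, 0], [1, 1], [-1, -1]]
truncated = nonzero_static_ref(x, size=2)  # [[0, 0], [1, 0]]
```

Because the number of rows depends only on `size` and not on the data, the output shape is known ahead of time, which is the property the export path relies on.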

Test Plan:
**Unit Tests**
```
buck test @mode/dev-nosan //caffe2/test:test_dynamo -- 'caffe2/test:test_dynamo - test_export.py::ExportTests::test_export_with_nonzero_static' -- 'caffe2/test:test_dynamo - test_misc.py::MiscTests::test_nonzero_static'
```

**PT2 Export with `nonzero_static()`**
Example of the `GraphModule` code in the exported graph:
```
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    nonzero_static_default = torch.ops.aten.nonzero_static.default(arg0, size = 4);  arg0 = None
    return pytree.tree_unflatten([nonzero_static_default], self._out_spec)
```

Differential Revision: D44324808

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97417
Approved by: https://github.com/ezyang
This commit is contained in:
Guang Yang
2023-04-11 05:13:36 +00:00
committed by PyTorch MergeBot
parent d4ce045cfc
commit c377a8590b
10 changed files with 350 additions and 0 deletions


@@ -3609,6 +3609,58 @@ See :func:`torch.nonzero`
""",
)
add_docstr_all(
"nonzero_static",
r"""
nonzero_static(input, *, size, fill_value=-1) -> Tensor
Returns a 2-D tensor where each row is the index of a non-zero value in `input`.
The returned tensor has the same `torch.dtype` as the output of `torch.nonzero()`.
Args:
    input (Tensor): the input tensor whose non-zero elements are located.
Keyword args:
    size (int): the number of non-zero elements (rows) to include in the output
        tensor. The output is padded with `fill_value` if `size` is larger
        than the total number of non-zero elements, and truncated if `size`
        is smaller. Must be a non-negative integer.
    fill_value (int): the value used to pad the output tensor when `size` is larger
        than the total number of non-zero elements. Defaults to `-1`, which
        represents an invalid index.
Example:
    # Example 1: Padding
    >>> input_tensor = torch.tensor([[1, 0], [3, 2]])
    >>> static_size = 4
    >>> torch.nonzero_static(input_tensor, size=static_size)
    tensor([[ 0,  0],
            [ 1,  0],
            [ 1,  1],
            [-1, -1]], dtype=torch.int64)
    # Example 2: Truncating
    >>> input_tensor = torch.tensor([[1, 0], [3, 2]])
    >>> static_size = 2
    >>> torch.nonzero_static(input_tensor, size=static_size)
    tensor([[ 0,  0],
            [ 1,  0]], dtype=torch.int64)
    # Example 3: 0 size
    >>> input_tensor = torch.tensor([10])
    >>> static_size = 0
    >>> torch.nonzero_static(input_tensor, size=static_size)
    tensor([], size=(0, 1), dtype=torch.int64)
    # Example 4: 0 rank input
    >>> input_tensor = torch.tensor(10)
    >>> static_size = 2
    >>> torch.nonzero_static(input_tensor, size=static_size)
    tensor([], size=(2, 0), dtype=torch.int64)
""",
)
add_docstr_all(
"norm",
r"""