Add nonzero_static() op to pytorch to unblock export (#97417)
Summary: Adds a new experimental Python op (`torch.nonzero_static`) for export. There is NO CUDA implementation included in this PR.
Example:
Say the input tensor is `x = torch.tensor([[1, 0], [3, 2]])`.
Calling regular `nonzero()` on `x` returns `tensor([[0, 0], [1, 0], [1, 1]])`.
Calling `nonzero_static(x, size=4)` returns `tensor([[0, 0], [1, 0], [1, 1], [fill_value, fill_value]])` (padded).
Calling `nonzero_static(x, size=2)` returns `tensor([[0, 0], [1, 0]])` (truncated; see the sketch below).
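A minimal, runnable sketch of that padding/truncation behavior (assuming a CPU build that already includes this PR's `torch.nonzero_static`; expected outputs are taken from the examples above):
```
import torch

x = torch.tensor([[1, 0], [3, 2]])

# Regular nonzero(): output shape depends on the data (3 non-zero elements here).
print(torch.nonzero(x))
# tensor([[0, 0],
#         [1, 0],
#         [1, 1]])

# size=4 > number of non-zeros: result is padded with fill_value (-1 by default).
print(torch.nonzero_static(x, size=4))
# tensor([[ 0,  0],
#         [ 1,  0],
#         [ 1,  1],
#         [-1, -1]])

# size=2 < number of non-zeros: result is truncated to the first two indices.
print(torch.nonzero_static(x, size=2))
# tensor([[0, 0],
#         [1, 0]])
```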
Test Plan:
**Unit Tests**
```
buck test @mode/dev-nosan //caffe2/test:test_dynamo -- 'caffe2/test:test_dynamo - test_export.py::ExportTests::test_export_with_nonzero_static' -- 'caffe2/test:test_dynamo - test_misc.py::MiscTests::test_nonzero_static'
```
**PT2 Export with `nonzero_static()`**
Example of the resulting `GraphModule` in the exported graph:
```
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    nonzero_static_default = torch.ops.aten.nonzero_static.default(arg0, size = 4); arg0 = None
    return pytree.tree_unflatten([nonzero_static_default], self._out_spec)
```
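For context, here is a hedged sketch of how a graph like the one above could be produced through the Dynamo export path. The exact `torch._dynamo.export` signature (positional example inputs, `aten_graph` flag, return value) has shifted between releases, so treat this as an assumption rather than the PR's own test code:
```
import torch
import torch._dynamo as dynamo

def fn(x):
    # A fixed output size keeps the op's shape static, which is what makes it export-friendly.
    return torch.nonzero_static(x, size=4)

# Assumed export API of this era: torch._dynamo.export(fn, *example_inputs, aten_graph=True)
# returning (GraphModule, guards); newer releases expose different entry points.
gm, guards = dynamo.export(fn, torch.tensor([[1, 0], [3, 2]]), aten_graph=True)
print(gm.code)  # expected to contain a call to torch.ops.aten.nonzero_static.default(..., size = 4)
```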
Differential Revision: D44324808
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97417
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: d4ce045cfc
Commit: c377a8590b
```
@@ -3609,6 +3609,58 @@ See :func:`torch.nonzero`
 """,
 )

+add_docstr_all(
+    "nonzero_static",
+    r"""
+nonzero_static(input, *, size, fill_value=-1) -> Tensor
+
+Returns a 2-D tensor where each row is the index for a non-zero value.
+The returned Tensor has the same `torch.dtype` as `torch.nonzero()`.
+
+Args:
+    input (Tensor): the input tensor to count non-zero elements.
+
+Keyword args:
+    size (int): the size of non-zero elements expected to be included in the out
+        tensor. Pad the out tensor with `fill_value` if the `size` is larger
+        than total number of non-zero elements, truncate out tensor if `size`
+        is smaller. The size must be a non-negative integer.
+    fill_value (int): the value to fill the output tensor with when `size` is larger
+        than the total number of non-zero elements. Default is `-1` to represent
+        invalid index.
+
+Example:
+
+    # Example 1: Padding
+    >>> input_tensor = torch.tensor([[1, 0], [3, 2]])
+    >>> static_size = 4
+    >>> t = torch.nonzero_static(input_tensor, size = static_size)
+    tensor([[ 0,  0],
+            [ 1,  0],
+            [ 1,  1],
+            [-1, -1]], dtype=torch.int64)
+
+    # Example 2: Truncating
+    >>> input_tensor = torch.tensor([[1, 0], [3, 2]])
+    >>> static_size = 2
+    >>> t = torch.nonzero_static(input_tensor, size = static_size)
+    tensor([[ 0,  0],
+            [ 1,  0]], dtype=torch.int64)
+
+    # Example 3: 0 size
+    >>> input_tensor = torch.tensor([10])
+    >>> static_size = 0
+    >>> t = torch.nonzero_static(input_tensor, size = static_size)
+    tensor([], size=(0, 1), dtype=torch.int64)
+
+    # Example 4: 0 rank input
+    >>> input_tensor = torch.tensor(10)
+    >>> static_size = 2
+    >>> t = torch.nonzero_static(input_tensor, size = static_size)
+    tensor([], size=(2, 0), dtype=torch.int64)
+""",
+)
+
 add_docstr_all(
     "norm",
     r"""
```