Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-06 09:17:11 +08:00).
This reverts commit 6ae4e3a8d249a96d9a8bbfba389d0509783e11e1. Reverted https://github.com/pytorch/pytorch/pull/112851 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/112851#issuecomment-1799539354))
15 lines · 369 B · Python
import torch
|
|
import torch._custom_ops as library
|
|
from model import get_custom_op_library_path
|
|
|
|
# Load the compiled custom-op shared library at import time. This library
# presumably defines the "custom::nonzero" operator that the abstract impl
# below is registered against (TODO confirm against the library sources).
# It must run before that decorator executes.
torch.ops.load_library(get_custom_op_library_path())
|
|
|
|
|
|
@library.impl_abstract("custom::nonzero")
def nonzero_abstract(x):
    """Abstract (shape-only) implementation of the ``custom::nonzero`` op.

    Builds an uninitialized ``torch.long`` tensor of shape
    ``[nnz, x.dim()]`` with ``x.new_empty``. Here ``nnz`` is a fresh unbacked
    symbolic integer obtained from the current abstract-impl context. No
    element values of ``x`` are read; only its rank is used.

    NOTE(review): the ``[nnz, ndim]`` / ``torch.long`` layout presumably
    mirrors ``torch.nonzero``. Confirm against the real kernel in the
    loaded custom-op library.

    Args:
        x: Input tensor. Only ``x.dim()`` and ``x.new_empty`` are used.

    Returns:
        A new empty tensor of shape ``[nnz, x.dim()]`` and dtype
        ``torch.long``.
    """
    # One column per dimension of the input.
    num_dims = x.dim()
    # The row count is not derivable from x's metadata, so it is represented
    # by an unbacked SymInt allocated through the abstract-impl context.
    num_nonzero = library.get_ctx().create_unbacked_symint()
    return x.new_empty([num_nonzero, num_dims], dtype=torch.long)
|