mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129758 Approved by: https://github.com/ezyang
16 lines
347 B
Python
16 lines
347 B
Python
from model import get_custom_op_library_path

import torch


# Load the compiled extension library into this process; presumably it is what
# defines the ``custom::nonzero`` operator that a fake kernel is attached to
# later in this module — TODO confirm against the C++ sources.
torch.ops.load_library(path=get_custom_op_library_path())


@torch.library.register_fake("custom::nonzero")
def nonzero_abstract(x):
    """Fake (tracing-time) implementation of ``custom::nonzero``.

    Builds an empty ``torch.long`` tensor of shape ``(nnz, x.dim())`` via
    ``x.new_empty``, so it inherits ``x``'s device. The row count ``nnz``
    depends on the runtime values in ``x``, which are not available while
    tracing, so it is represented by a fresh data-dependent symbolic size.
    (Presumably this mirrors the layout of ``torch.nonzero`` — confirm
    against the C++ kernel.)

    Args:
        x: The (fake) input tensor; any rank.

    Returns:
        An empty ``torch.long`` tensor of shape ``[nnz, x.dim()]``.
    """
    ctx = torch.library.get_ctx()
    # ``new_dynamic_size`` has a default lower bound of 0, so results with zero
    # or one row are representable. The deprecated ``create_unbacked_symint``
    # defaulted to ``min=2``, which wrongly told the tracer that at least two
    # elements are always nonzero.
    nnz = ctx.new_dynamic_size()
    return x.new_empty([nnz, x.dim()], dtype=torch.long)