mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Summary:
We've made the following changes:
- The new way to use the API is `m.impl_abstract_pystub(module, context)`. Every subsequent m.def of an op inside the TORCH_LIBRARY block gives the op the `impl_abstract_pystub`.
- Added a mechanism to determine if an operator was defined in Python or C++. Library.define in Python appends the op to a global set, which is analogous to what we do for tracking Library.impl.
- If someone does `torch.library.impl_abstract` in Python for an operator, then we require that it has an `impl_abstract_pystub` specified and we also check that the module in the `impl_abstract_pystub` is the same as the module where the call to `torch.library.impl_abstract` exists.
- Unfortunately we can't check the "context" (which is the buck target on buck-based systems) because buck sits above us.

Test Plan:
- existing tests

Differential Revision: D50972148
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112851
Approved by: https://github.com/ezyang
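The bookkeeping described in the commit message can be modeled in a few lines of plain Python. This is a toy sketch, not PyTorch's implementation: every name below (`PYTHON_DEFINED_OPS`, `PYSTUB_MODULES`, `define_in_python`, `define_in_cpp`, `register_abstract`) is hypothetical, and it assumes that "the module where the call to `torch.library.impl_abstract` exists" can be approximated by the decorated function's `__module__`. As in the commit, the `context` argument is accepted but never checked.

```python
# Toy model (hypothetical, not PyTorch internals) of how an impl_abstract-style
# decorator could enforce the impl_abstract_pystub rule described above.
from typing import Callable, Dict, Optional, Set

# Ops defined from Python, analogous to Library.define appending to a global set.
PYTHON_DEFINED_OPS: Set[str] = set()

# C++-defined ops mapped to the Python module named in their pystub.
PYSTUB_MODULES: Dict[str, str] = {}

# Registered abstract impls, keyed by qualified op name.
ABSTRACT_IMPLS: Dict[str, Callable] = {}


def define_in_python(qualname: str) -> None:
    """Stand-in for Library.define: remember that the op came from Python."""
    PYTHON_DEFINED_OPS.add(qualname)


def define_in_cpp(qualname: str,
                  pystub_module: Optional[str] = None,
                  context: Optional[str] = None) -> None:
    """Stand-in for an m.def that follows m.impl_abstract_pystub(module, context).

    `context` (e.g. a buck target) is accepted but ignored, since it cannot
    be checked from Python.
    """
    if pystub_module is not None:
        PYSTUB_MODULES[qualname] = pystub_module


def register_abstract(qualname: str) -> Callable[[Callable], Callable]:
    """Stand-in for torch.library.impl_abstract with the pystub check."""
    def decorator(func: Callable) -> Callable:
        if qualname not in PYTHON_DEFINED_OPS:
            # The op was defined in C++, so it must carry a pystub...
            expected = PYSTUB_MODULES.get(qualname)
            if expected is None:
                raise RuntimeError(f"{qualname} has no impl_abstract_pystub")
            # ...and the abstract impl must live in the module the pystub names.
            if func.__module__ != expected:
                raise RuntimeError(
                    f"{qualname}: pystub names {expected!r}, but the abstract "
                    f"impl is registered from {func.__module__!r}"
                )
        ABSTRACT_IMPLS[qualname] = func
        return func
    return decorator
```

For example, after `define_in_cpp("custom::nonzero", pystub_module="my_pkg.custom_ops")` (a hypothetical module name), registering an abstract impl from any other module raises, while ops defined through `define_in_python` skip the check entirely.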
14 lines, 345 B, Python
import torch

from model import get_custom_op_library_path

torch.ops.load_library(get_custom_op_library_path())


@torch.library.impl_abstract("custom::nonzero")
def nonzero_abstract(x):
    n = x.dim()
    ctx = torch.library.get_ctx()
    nnz = ctx.create_unbacked_symint()
    shape = [nnz, n]
    return x.new_empty(shape, dtype=torch.long)
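Assuming `custom::nonzero` mirrors `torch.nonzero` (as its `[nnz, x.dim()]` long-typed output suggests), the result has one row per nonzero element and one column per input dimension. The column count follows from the input's rank alone, but the row count depends on the values, which is why `nonzero_abstract` allocates `nnz` with `ctx.create_unbacked_symint()`. The pure-Python sketch below illustrates that split using nested lists as a stand-in for tensors; all function names and the `"u0"` symbol are hypothetical placeholders, not PyTorch APIs.

```python
# Illustrative sketch: concrete vs. abstract output shapes of a nonzero op,
# using rectangular nested lists as stand-ins for tensors.
from itertools import product
from typing import Any, List, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def element_at(x: Any, idx: Tuple[int, ...]) -> Any:
    """Index into a nested list with a tuple of indices."""
    for i in idx:
        x = x[i]
    return x


def nonzero_concrete(x: Any) -> List[List[int]]:
    """Indices of nonzero elements in row-major order: needs the actual data."""
    shape = shape_of(x)
    return [
        list(idx)
        for idx in product(*(range(d) for d in shape))
        if element_at(x, idx) != 0
    ]


def nonzero_abstract_shape(x_shape: Tuple[int, ...],
                           nnz_symbol: str = "u0") -> Tuple[str, int]:
    """Output shape from metadata only: the rank is known, but the row count
    is a fresh symbol (playing the role of an unbacked SymInt)."""
    return (nnz_symbol, len(x_shape))
```

Two inputs with the same shape can yield different row counts from `nonzero_concrete`, yet `nonzero_abstract_shape` returns the same symbolic shape for both; that is the situation an unbacked symbolic size is meant to describe.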