mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[custom_op] Create a new torch._custom_op namespace (#101823)
torch/custom_op.py is getting long, and the autograd pieces are going to make it even longer. I'm planning on just organizing the files under a torch/_custom_op folder. Note that the imports now look a bit crazy (from torch._custom_op.impl import...) but they will look more OK when we figure out the plan to make custom_op public (coming later). Pull Request resolved: https://github.com/pytorch/pytorch/pull/101823 Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
73d1be8e99
commit
8487105fae
@ -15,7 +15,7 @@ from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorR
|
||||
log_input, capture_logs, capture_logs_with_logging_tensor_mode
|
||||
from torch.utils._pytree import tree_map, tree_map_only
|
||||
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode, _get_current_dispatch_mode_stack
|
||||
from torch._custom_op import custom_op, CustomOp
|
||||
from torch._custom_op.impl import custom_op, CustomOp
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
import typing
|
||||
import collections
|
||||
@ -416,11 +416,11 @@ class TestCustomOp(TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
import torch._custom_op
|
||||
keys = list(torch._custom_op.global_registry.keys())
|
||||
keys = list(torch._custom_op.impl.global_registry.keys())
|
||||
for key in keys:
|
||||
if not key.startswith(f'{TestCustomOp.test_ns}::'):
|
||||
continue
|
||||
torch._custom_op.global_registry[key]._destroy()
|
||||
torch._custom_op.impl.global_registry[key]._destroy()
|
||||
|
||||
def test_invalid_schemas(self):
|
||||
# function schmea validation goes through torchgen, so this is just a
|
||||
@ -582,7 +582,7 @@ class TestCustomOp(TestCase):
|
||||
return list(itertools.product(examples, examples)) + []
|
||||
raise AssertionError(f"unsupported param type {typ}")
|
||||
|
||||
for typ in torch._custom_op.SUPPORTED_PARAM_TYPES:
|
||||
for typ in torch._custom_op.impl.SUPPORTED_PARAM_TYPES:
|
||||
@custom_op(f'{TestCustomOp.test_ns}::foo')
|
||||
def foo(x: Tensor, y: typ) -> Tensor:
|
||||
...
|
||||
@ -716,7 +716,7 @@ class TestCustomOp(TestCase):
|
||||
op._destroy()
|
||||
|
||||
def test_reserved_ns(self):
|
||||
from torch._custom_op import RESERVED_NS
|
||||
from torch._custom_op.impl import RESERVED_NS
|
||||
|
||||
for ns in RESERVED_NS:
|
||||
with self.assertRaisesRegex(ValueError, 'is a reserved namespace'):
|
||||
@ -810,7 +810,7 @@ class TestCustomOp(TestCase):
|
||||
def foo_impl(x):
|
||||
return x.sin()
|
||||
|
||||
from torch._custom_op import SUPPORTED_DEVICE_TYPE_TO_KEY
|
||||
from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY
|
||||
|
||||
for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys():
|
||||
# Smoke test: should not raise error
|
||||
@ -911,7 +911,7 @@ class TestCustomOp(TestCase):
|
||||
|
||||
@foo.impl_abstract()
|
||||
def foo_meta(x):
|
||||
ctx = torch._custom_op.get_ctx()
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
with self.assertRaisesRegex(ValueError, "greater than or equal to 2"):
|
||||
ctx.create_unbacked_symint(min=1)
|
||||
with self.assertRaisesRegex(ValueError, "greater than or equal to 2"):
|
||||
|
0
torch/_custom_op/__init__.py
Normal file
0
torch/_custom_op/__init__.py
Normal file
@ -2,7 +2,7 @@ import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from torch._custom_op import custom_op
|
||||
from torch._custom_op.impl import custom_op
|
||||
from torch.utils._content_store import ContentStoreReader
|
||||
|
||||
LOAD_TENSOR_READER = None
|
||||
|
@ -1301,11 +1301,11 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
|
||||
# Users can register FakeTensor rules for custom operators
|
||||
# Call them if they exist.
|
||||
if func.name() in torch._custom_op.global_registry:
|
||||
custom_op = torch._custom_op.global_registry[func.name()]
|
||||
if func.name() in torch._custom_op.impl.global_registry:
|
||||
custom_op = torch._custom_op.impl.global_registry[func.name()]
|
||||
if custom_op is not None and custom_op._has_impl("abstract"):
|
||||
ctx = torch._custom_op.AbstractImplCtx(self.shape_env, func)
|
||||
with torch._custom_op.set_ctx_getter(lambda: ctx), self:
|
||||
ctx = torch._custom_op.impl.AbstractImplCtx(self.shape_env, func)
|
||||
with torch._custom_op.impl.set_ctx_getter(lambda: ctx), self:
|
||||
result = custom_op._get_impl("abstract").func(*args, **kwargs)
|
||||
return result
|
||||
|
||||
|
@ -7,7 +7,7 @@ from torch.testing._internal.opinfo.core import (
|
||||
)
|
||||
from torch.testing._internal.common_dtype import all_types_and
|
||||
import numpy as np
|
||||
from torch._custom_op import custom_op
|
||||
from torch._custom_op.impl import custom_op
|
||||
from torch.testing._internal.autograd_function_db import (
|
||||
sample_inputs_numpy_cube,
|
||||
sample_inputs_numpy_mul,
|
||||
@ -110,7 +110,7 @@ def numpy_nonzero_impl(x):
|
||||
|
||||
@numpy_nonzero.impl_abstract()
|
||||
def numpy_nonzero_abstract(x):
|
||||
ctx = torch._custom_op.get_ctx()
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
i0 = ctx.create_unbacked_symint()
|
||||
shape = [x.dim(), i0]
|
||||
result = x.new_empty(shape, dtype=torch.long)
|
||||
@ -199,7 +199,7 @@ def numpy_nms_abstract(boxes, scores, iou_threshold):
|
||||
assert boxes.shape == (N, 4)
|
||||
assert scores.shape == (N,)
|
||||
|
||||
ctx = torch._custom_op.get_ctx()
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
i0 = ctx.create_unbacked_symint()
|
||||
result = boxes.new_empty([i0, 4])
|
||||
return result
|
||||
|
Reference in New Issue
Block a user