[custom_op] Create a new torch._custom_op namespace (#101823)

torch/_custom_op.py is getting long, and the autograd pieces are going to
make it even longer. I'm planning on organizing the files under a
torch/_custom_op folder.

Note that the imports now look a bit unwieldy (from torch._custom_op.impl
import ...), but they will look more reasonable once we figure out the plan
to make custom_op public (coming later).
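
For reference, the change to call sites is mechanical (a minimal sketch using only the names that appear in this diff):

    # Before this PR: the implementation lived in the torch/_custom_op.py module.
    # from torch._custom_op import custom_op, CustomOp

    # After this PR: torch/_custom_op is a package, so callers import from the
    # impl submodule instead.
    from torch._custom_op.impl import custom_op, CustomOp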
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101823
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/bdhirsh
Author: Richard Zou
Date: 2023-05-23 07:02:58 -07:00
Committed by: PyTorch MergeBot
Parent: 73d1be8e99
Commit: 8487105fae

6 changed files with 15 additions and 15 deletions


@@ -15,7 +15,7 @@ from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorR
log_input, capture_logs, capture_logs_with_logging_tensor_mode
from torch.utils._pytree import tree_map, tree_map_only
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode, _get_current_dispatch_mode_stack
-from torch._custom_op import custom_op, CustomOp
+from torch._custom_op.impl import custom_op, CustomOp
from torch.fx.experimental.proxy_tensor import make_fx
import typing
import collections
@@ -416,11 +416,11 @@ class TestCustomOp(TestCase):
def tearDown(self):
import torch._custom_op
-keys = list(torch._custom_op.global_registry.keys())
+keys = list(torch._custom_op.impl.global_registry.keys())
for key in keys:
if not key.startswith(f'{TestCustomOp.test_ns}::'):
continue
-torch._custom_op.global_registry[key]._destroy()
+torch._custom_op.impl.global_registry[key]._destroy()
def test_invalid_schemas(self):
# function schema validation goes through torchgen, so this is just a
@@ -582,7 +582,7 @@ class TestCustomOp(TestCase):
return list(itertools.product(examples, examples)) + []
raise AssertionError(f"unsupported param type {typ}")
-for typ in torch._custom_op.SUPPORTED_PARAM_TYPES:
+for typ in torch._custom_op.impl.SUPPORTED_PARAM_TYPES:
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor, y: typ) -> Tensor:
...
@@ -716,7 +716,7 @@ class TestCustomOp(TestCase):
op._destroy()
def test_reserved_ns(self):
-from torch._custom_op import RESERVED_NS
+from torch._custom_op.impl import RESERVED_NS
for ns in RESERVED_NS:
with self.assertRaisesRegex(ValueError, 'is a reserved namespace'):
@@ -810,7 +810,7 @@ class TestCustomOp(TestCase):
def foo_impl(x):
return x.sin()
-from torch._custom_op import SUPPORTED_DEVICE_TYPE_TO_KEY
+from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY
for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys():
# Smoke test: should not raise error
@@ -911,7 +911,7 @@ class TestCustomOp(TestCase):
@foo.impl_abstract()
def foo_meta(x):
-ctx = torch._custom_op.get_ctx()
+ctx = torch._custom_op.impl.get_ctx()
with self.assertRaisesRegex(ValueError, "greater than or equal to 2"):
ctx.create_unbacked_symint(min=1)
with self.assertRaisesRegex(ValueError, "greater than or equal to 2"):


@@ -2,7 +2,7 @@ import contextlib
from typing import Sequence
import torch
-from torch._custom_op import custom_op
+from torch._custom_op.impl import custom_op
from torch.utils._content_store import ContentStoreReader
LOAD_TENSOR_READER = None


@@ -1301,11 +1301,11 @@ class FakeTensorMode(TorchDispatchMode):
# Users can register FakeTensor rules for custom operators
# Call them if they exist.
-if func.name() in torch._custom_op.global_registry:
-custom_op = torch._custom_op.global_registry[func.name()]
+if func.name() in torch._custom_op.impl.global_registry:
+custom_op = torch._custom_op.impl.global_registry[func.name()]
if custom_op is not None and custom_op._has_impl("abstract"):
-ctx = torch._custom_op.AbstractImplCtx(self.shape_env, func)
-with torch._custom_op.set_ctx_getter(lambda: ctx), self:
+ctx = torch._custom_op.impl.AbstractImplCtx(self.shape_env, func)
+with torch._custom_op.impl.set_ctx_getter(lambda: ctx), self:
result = custom_op._get_impl("abstract").func(*args, **kwargs)
return result
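
For context, the registration pattern this dispatch path serves looks roughly like the sketch below. The mylib::my_nonzero op is hypothetical; custom_op, impl_abstract, get_ctx, and create_unbacked_symint are the APIs touched in this diff:

    import torch
    from torch import Tensor
    from torch._custom_op.impl import custom_op

    # Hypothetical custom op with a data-dependent output shape
    # (illustration only; not part of this commit).
    @custom_op('mylib::my_nonzero')
    def my_nonzero(x: Tensor) -> Tensor:
        ...

    @my_nonzero.impl_abstract()
    def my_nonzero_abstract(x):
        # get_ctx() returns the AbstractImplCtx that FakeTensorMode installed
        # via set_ctx_getter() above; the unbacked symint stands in for the
        # data-dependent number of nonzero entries.
        ctx = torch._custom_op.impl.get_ctx()
        i0 = ctx.create_unbacked_symint()
        return x.new_empty([x.dim(), i0], dtype=torch.long)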


@@ -7,7 +7,7 @@ from torch.testing._internal.opinfo.core import (
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np
-from torch._custom_op import custom_op
+from torch._custom_op.impl import custom_op
from torch.testing._internal.autograd_function_db import (
sample_inputs_numpy_cube,
sample_inputs_numpy_mul,
@@ -110,7 +110,7 @@ def numpy_nonzero_impl(x):
@numpy_nonzero.impl_abstract()
def numpy_nonzero_abstract(x):
-ctx = torch._custom_op.get_ctx()
+ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
shape = [x.dim(), i0]
result = x.new_empty(shape, dtype=torch.long)
@@ -199,7 +199,7 @@ def numpy_nms_abstract(boxes, scores, iou_threshold):
assert boxes.shape == (N, 4)
assert scores.shape == (N,)
-ctx = torch._custom_op.get_ctx()
+ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
result = boxes.new_empty([i0, 4])
return result
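
Assuming a registration like the ones above, the abstract impl fires whenever the op runs under a FakeTensorMode that carries a ShapeEnv. A rough usage sketch (my_nonzero is the hypothetical op from the earlier sketch, not part of this diff):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # Inside the mode, tensors are fake, so calling the op runs its "abstract"
    # impl and the output size comes back as an unbacked symint (e.g. u0)
    # rather than a concrete integer.
    with FakeTensorMode(shape_env=ShapeEnv()):
        x = torch.randn(5)
        out = my_nonzero(x)
        print(out.shape)  # something like torch.Size([1, u0])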