mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Autogen Tags enum, and allow specifying tags while defining an op
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79322 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
bdcee8f995
commit
38350acf8f
1
BUCK.oss
1
BUCK.oss
@ -160,6 +160,7 @@ ATEN_EXPORTED_HEADERS = {
|
||||
"RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]",
|
||||
"core/TensorBody.h": ":gen_aten[core/TensorBody.h]",
|
||||
"core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]",
|
||||
"core/enum_tag.h": ":gen_aten[core/enum_tag.h]",
|
||||
}
|
||||
|
||||
cxx_library(
|
||||
|
@ -73,6 +73,7 @@ generated_cpu_cpp = [
|
||||
"aten/src/ATen/NativeMetaFunctions.h",
|
||||
"aten/src/ATen/RegistrationDeclarations.h",
|
||||
"aten/src/ATen/core/aten_interned_strings.h",
|
||||
"aten/src/ATen/core/enum_tag.h",
|
||||
"aten/src/ATen/core/TensorBody.h",
|
||||
"aten/src/ATen/core/TensorMethods.cpp",
|
||||
"aten/src/ATen/core/ATenOpList.cpp",
|
||||
|
@ -147,7 +147,7 @@ void Dispatcher::deregisterLibrary_(const std::string& ns) {
|
||||
libraries_.erase(ns);
|
||||
}
|
||||
|
||||
RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug) {
|
||||
RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
|
||||
// we need a lock to avoid concurrent writes
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
@ -157,7 +157,7 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin
|
||||
TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
|
||||
" Each overload's schema should only be registered with a single call to def().",
|
||||
" Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
|
||||
op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug));
|
||||
op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug), tags);
|
||||
listeners_->callOnOperatorRegistered(op);
|
||||
|
||||
// NB: do not increment the counts until AFTER error checking
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/core/grad_mode.h>
|
||||
#include <ATen/core/enum_tag.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
@ -177,7 +178,7 @@ public:
|
||||
* If a schema with the same operator name and overload name already exists,
|
||||
* this function will check that both schemas are exactly identical.
|
||||
*/
|
||||
RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug);
|
||||
RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags = {});
|
||||
|
||||
/**
|
||||
* Register a kernel to the dispatch table for an operator.
|
||||
@ -338,6 +339,19 @@ public:
|
||||
return operatorDef_->op.checkInvariants();
|
||||
}
|
||||
|
||||
c10::ArrayRef<at::Tag> getTags() const {
|
||||
return operatorDef_->op.getTags();
|
||||
}
|
||||
|
||||
bool hasTag(const at::Tag& tag) const {
|
||||
for(const auto& tag_: getTags()) {
|
||||
if (tag == tag_) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template<class FuncType>
|
||||
TypedOperatorHandle<FuncType> typed() const {
|
||||
// NB: This assert is not 100% sound: you can retrieve a typed() operator
|
||||
|
@ -19,6 +19,9 @@ namespace {
|
||||
OperatorEntry::OperatorEntry(OperatorName&& operator_name)
|
||||
: name_(std::move(operator_name))
|
||||
, schema_()
|
||||
#ifndef C10_MOBILE
|
||||
, tags_()
|
||||
#endif
|
||||
, dispatchTable_()
|
||||
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
|
||||
, kernels_()
|
||||
@ -57,7 +60,7 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
|
||||
return kernel;
|
||||
}
|
||||
|
||||
void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug) {
|
||||
void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
|
||||
TORCH_INTERNAL_ASSERT(!schema_.has_value());
|
||||
for (const auto& kernel : kernels_) {
|
||||
for (const auto &j : kernel.second) {
|
||||
@ -69,6 +72,9 @@ void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug)
|
||||
// NB: don't register schema until after we've checked everything!
|
||||
dispatchKeyExtractor_.registerSchema(schema);
|
||||
schema_ = AnnotatedSchema(std::move(schema), std::move(debug));
|
||||
#ifndef C10_MOBILE
|
||||
tags_ = std::move(tags);
|
||||
#endif
|
||||
}
|
||||
|
||||
void OperatorEntry::deregisterSchema() {
|
||||
@ -208,6 +214,14 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const std::vector<at::Tag>& OperatorEntry::getTags() const {
|
||||
#if defined C10_MOBILE
|
||||
TORCH_CHECK(false, "tags are not saved for Mobile");
|
||||
#else
|
||||
return tags_;
|
||||
#endif
|
||||
}
|
||||
|
||||
std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
|
||||
// [Note] DispatchTable computation
|
||||
// dispatchTable contains entries for runtime dispatch keys.
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <ATen/core/dispatch/OperatorOptions.h>
|
||||
#include <ATen/core/dispatch/CppSignature.h>
|
||||
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
|
||||
#include <ATen/core/enum_tag.h>
|
||||
|
||||
#include <list>
|
||||
#include <array>
|
||||
@ -98,7 +99,7 @@ public:
|
||||
// attempt to register a schema when one is already present or vice
|
||||
// versa that is an error. (Refcounting for the registrations is
|
||||
// handled in the OperatorHandle in Dispatcher)
|
||||
void registerSchema(FunctionSchema&&, std::string&& debug);
|
||||
void registerSchema(FunctionSchema&&, std::string&& debug, std::vector<at::Tag> tags = {});
|
||||
void deregisterSchema();
|
||||
|
||||
const OperatorName& operator_name() const {
|
||||
@ -205,12 +206,16 @@ public:
|
||||
bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
|
||||
// Returns true if kernel_ has entry for a particular key.
|
||||
bool hasKernelForDispatchKey(DispatchKey k) const;
|
||||
// Returns all the operator tags added at the time of registration
|
||||
const std::vector<at::Tag>& getTags() const;
|
||||
|
||||
private:
|
||||
|
||||
OperatorName name_;
|
||||
c10::optional<AnnotatedSchema> schema_;
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
std::vector<at::Tag> tags_;
|
||||
#endif
|
||||
std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
|
||||
DispatchKeyExtractor dispatchKeyExtractor_;
|
||||
|
||||
|
@ -89,7 +89,7 @@ Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, c
|
||||
// merge everything
|
||||
|
||||
#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
|
||||
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & {
|
||||
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags) & {
|
||||
TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
|
||||
DEF_PRELUDE,
|
||||
"Cannot define an operator inside of a ", toString(kind_), " block. "
|
||||
@ -128,7 +128,8 @@ Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name
|
||||
registrars_.emplace_back(
|
||||
c10::Dispatcher::singleton().registerDef(
|
||||
std::move(schema),
|
||||
debugString(file_, line_)
|
||||
debugString(file_, line_),
|
||||
tags
|
||||
)
|
||||
);
|
||||
return *this;
|
||||
|
@ -12,3 +12,7 @@
|
||||
desc: |
|
||||
This tag indicates if an operator's output's shape depends on input Tensor
|
||||
data.
|
||||
- tag: generated
|
||||
desc: |
|
||||
This tag indicates that the operator doesn't have an explicit entry in
|
||||
native_functions.yaml, and instead was generated automatically by the codegen.
|
||||
|
10
aten/src/ATen/templates/enum_tag.h
Normal file
10
aten/src/ATen/templates/enum_tag.h
Normal file
@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
// ${generated_comment}
|
||||
|
||||
namespace at {
|
||||
// Enum of valid tags obtained from the entries in tags.yaml
|
||||
enum class Tag {
|
||||
${enum_of_valid_tags}
|
||||
};
|
||||
}
|
@ -101,6 +101,7 @@ GENERATED_H_CORE = [
|
||||
"core/TensorBody.h",
|
||||
"MethodOperators.h",
|
||||
"core/aten_interned_strings.h",
|
||||
"core/enum_tag.h",
|
||||
]
|
||||
|
||||
GENERATED_H_CUDA = [
|
||||
@ -186,6 +187,7 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
|
||||
"torch/csrc/autograd/generated/python_fft_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_linalg_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_return_types.cpp",
|
||||
"torch/csrc/autograd/generated/python_enum_tag.cpp",
|
||||
"torch/csrc/autograd/generated/python_sparse_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_special_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_torch_functions_0.cpp",
|
||||
|
@ -959,6 +959,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
|
||||
"torch/csrc/autograd/generated/python_nn_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_fft_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_linalg_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_enum_tag.cpp",
|
||||
"torch/csrc/autograd/generated/python_return_types.cpp",
|
||||
"torch/csrc/autograd/generated/python_sparse_functions.cpp",
|
||||
"torch/csrc/autograd/generated/python_special_functions.cpp",
|
||||
|
@ -401,6 +401,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
|
||||
)
|
||||
|
||||
set(GENERATED_H_PYTHON
|
||||
@ -463,6 +464,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
"${TOOLS_PATH}/autograd/templates/python_sparse_functions.cpp"
|
||||
"${TOOLS_PATH}/autograd/templates/python_special_functions.cpp"
|
||||
"${TOOLS_PATH}/autograd/templates/python_return_types.cpp"
|
||||
"${TOOLS_PATH}/autograd/templates/python_enum_tag.cpp"
|
||||
"${TOOLS_PATH}/autograd/templates/variable_factories.h"
|
||||
"${TOOLS_PATH}/autograd/templates/annotated_fn_args.py.in"
|
||||
"${TOOLS_PATH}/autograd/deprecated.yaml"
|
||||
|
@ -613,6 +613,10 @@ Utilities
|
||||
vmap
|
||||
_assert
|
||||
|
||||
Operator Tags
|
||||
------------------------------------
|
||||
.. autoclass:: Tag
|
||||
:members:
|
||||
|
||||
.. Empty submodules added only for tracking.
|
||||
.. py:module:: torch.contrib
|
||||
|
@ -357,6 +357,7 @@ def get_aten_generated_files(enabled_backends):
|
||||
"core/TensorBody.h",
|
||||
"core/TensorMethods.cpp",
|
||||
"core/aten_interned_strings.h",
|
||||
"core/enum_tag.h",
|
||||
] + get_aten_derived_type_srcs(enabled_backends)
|
||||
|
||||
# This is tiresome. A better strategy would be to unconditionally
|
||||
|
@ -65,6 +65,7 @@ import torch.testing._internal.opinfo_helper as opinfo_helper
|
||||
from torch.testing._internal import composite_compliance
|
||||
|
||||
from torch.utils._pytree import tree_flatten
|
||||
from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode
|
||||
|
||||
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
|
||||
torch.set_default_dtype(torch.float32)
|
||||
@ -1432,6 +1433,57 @@ class TestMathBits(TestCase):
|
||||
torch.is_complex,
|
||||
)
|
||||
|
||||
# input strides and size may have been altered due to the result of an inplace op
|
||||
def test_inplace_view(func, input, rs, input_size, input_strides):
|
||||
if func is None:
|
||||
return
|
||||
# TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out
|
||||
# which mutate not necessarily the first input.
|
||||
if isinstance(rs, torch.Tensor) and rs is input:
|
||||
unequal_size = rs.size() != input_size
|
||||
unequal_strides = rs.stride() != input_strides
|
||||
# resize_ should probably have inplace_view tag. Not adding the tag since it
|
||||
# breaks some codegen logic
|
||||
if (unequal_size or unequal_strides):
|
||||
if isinstance(func, torch._ops.OpOverloadPacket):
|
||||
func = func.default
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/78759
|
||||
if func is not torch.ops.aten.resize_.default:
|
||||
# TODO: use self.assertIn when we have separate tests for each tag
|
||||
assert torch.Tag.inplace_view in func.tags
|
||||
|
||||
# A mode that when enabled runs correctness checks to ensure
|
||||
# that operators have expected tags based on their input and
|
||||
# ouput tensor properties
|
||||
class TestTagsMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
if isinstance(args[0], torch.Tensor):
|
||||
old_size = args[0].size()
|
||||
old_stride = args[0].stride()
|
||||
rs = func(*args, **kwargs)
|
||||
test_inplace_view(func, args[0], rs, old_size, old_stride)
|
||||
else:
|
||||
rs = func(*args, **kwargs)
|
||||
return rs
|
||||
|
||||
# Test to verify the correctness for tags in `tags.yaml`, also available for access through `torch.Tags`
|
||||
class TestTags(TestCase):
|
||||
@onlyCPU
|
||||
@ops(ops_and_refs, dtypes=OpDTypes.any_one)
|
||||
def test_tags(self, device, dtype, op):
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
||||
for sample in samples:
|
||||
# TODO: Test tags for ops that return a list of tensors
|
||||
input = sample.input
|
||||
if isinstance(input, torch.Tensor):
|
||||
old_size = input.size()
|
||||
old_stride = input.stride()
|
||||
with push_torch_dispatch_mode(TestTagsMode):
|
||||
rs = op(input, *sample.args, **sample.kwargs)
|
||||
# TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
|
||||
aten_name = op.aten_name if op.aten_name is not None else op.name
|
||||
opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
|
||||
test_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
|
||||
|
||||
|
||||
class TestRefsOpsInfo(TestCase):
|
||||
@ -1574,8 +1626,7 @@ instantiate_device_type_tests(TestCompositeCompliance, globals())
|
||||
instantiate_device_type_tests(TestMathBits, globals())
|
||||
instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
|
||||
instantiate_device_type_tests(TestFakeTensorNonErroring, globals())
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestTags, globals())
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -241,6 +241,11 @@ class TestPublicBindings(TestCase):
|
||||
"vitals_enabled",
|
||||
|
||||
"wait",
|
||||
"Tag",
|
||||
"inplace_view",
|
||||
"view_copy",
|
||||
"generated",
|
||||
"dynamic_output_shape",
|
||||
}
|
||||
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
|
||||
|
||||
|
@ -25,6 +25,7 @@ python_library(
|
||||
"templates/python_torch_functions.cpp",
|
||||
"templates/python_variable_methods.cpp",
|
||||
"templates/variable_factories.h",
|
||||
"templates/python_enum_tag.cpp",
|
||||
],
|
||||
visibility = ["PUBLIC"],
|
||||
deps = [
|
||||
|
@ -57,7 +57,7 @@ from torchgen.api.python import (
|
||||
namedtuple_fieldnames,
|
||||
signature,
|
||||
)
|
||||
from torchgen.gen import cpp_string, parse_native_yaml
|
||||
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
|
||||
from torchgen.context import with_native_function
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
@ -326,6 +326,17 @@ def gen(
|
||||
fm, functions, lambda fn: True, "python_return_types.cpp"
|
||||
)
|
||||
|
||||
valid_tags = parse_tags_yaml(tags_yaml_path)
|
||||
|
||||
def gen_tags_enum() -> Dict[str, str]:
|
||||
return {
|
||||
"enum_of_valid_tags": (
|
||||
"".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags])
|
||||
)
|
||||
}
|
||||
|
||||
fm.write("python_enum_tag.cpp", gen_tags_enum)
|
||||
|
||||
|
||||
def group_filter_overloads(
|
||||
pairs: Sequence[PythonSignatureNativeFunctionPair],
|
||||
|
15
tools/autograd/templates/python_enum_tag.cpp
Normal file
15
tools/autograd/templates/python_enum_tag.cpp
Normal file
@ -0,0 +1,15 @@
|
||||
#include <torch/csrc/autograd/python_enum_tag.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <ATen/core/enum_tag.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
void initEnumTag(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
py::enum_<at::Tag>(m, "Tag")
|
||||
${enum_of_valid_tags}
|
||||
.export_values();
|
||||
m.doc() = "An Enum that contains tags that can be assigned to an operator registered in C++.";
|
||||
}
|
||||
}}
|
@ -202,7 +202,7 @@ def _jit_init() -> _bool: ...
|
||||
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
|
||||
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
|
||||
def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ...
|
||||
def _get_operation_overload(op_name: str, op_overload_name: str) -> Callable: ...
|
||||
def _get_operation_overload(op_name: str, op_overload_name: str) -> Tuple[Callable, List[Any]]: ...
|
||||
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
|
||||
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
|
||||
optimization_blocklist: Set[MobileOptimizerType],
|
||||
|
@ -28,10 +28,11 @@ def dl_open_guard():
|
||||
# Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
|
||||
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
|
||||
class OpOverload:
|
||||
def __init__(self, overloadpacket, op, schema):
|
||||
def __init__(self, overloadpacket, op, schema, tags):
|
||||
self._op = op
|
||||
self._schema = schema
|
||||
self._overloadpacket = overloadpacket
|
||||
self._tags = tags
|
||||
self._overloadname = 'default' if schema.overload_name == '' else schema.overload_name
|
||||
self.__name__ = "{}.{}".format(self._schema.name.split("::")[1], self._overloadname)
|
||||
self.__module__ = overloadpacket.__module__
|
||||
@ -65,6 +66,10 @@ class OpOverload:
|
||||
def op(self):
|
||||
return self._op
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
return self._tags
|
||||
|
||||
# TODO: add more methods to expose information about input and output arguments
|
||||
|
||||
# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
|
||||
@ -123,10 +128,10 @@ class OpOverloadPacket:
|
||||
# This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
|
||||
use_key = '' if key == 'default' else key
|
||||
# TODO: disallow access to overloads registered by JIT
|
||||
op_ = torch._C._get_operation_overload(
|
||||
op_, tags = torch._C._get_operation_overload(
|
||||
self._qualified_op_name, use_key)
|
||||
schema = torch._C._get_schema(self._qualified_op_name, use_key)
|
||||
overload = OpOverload(self, op_, schema)
|
||||
overload = OpOverload(self, op_, schema, tags)
|
||||
# cache the overload object
|
||||
setattr(self, key, overload)
|
||||
return overload
|
||||
|
@ -41,6 +41,7 @@
|
||||
#include <torch/csrc/autograd/python_sparse_functions.h>
|
||||
#include <torch/csrc/autograd/python_special_functions.h>
|
||||
#include <torch/csrc/autograd/python_return_types.h>
|
||||
#include <torch/csrc/autograd/python_enum_tag.h>
|
||||
#include <torch/csrc/autograd/python_legacy_variable.h>
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
#include <torch/csrc/multiprocessing/init.h>
|
||||
@ -828,6 +829,7 @@ PyObject* initModule() {
|
||||
// the export side of JIT, so this ONNX init needs to appear before the JIT
|
||||
// init.
|
||||
torch::onnx::initONNXBindings(module);
|
||||
torch::autograd::initEnumTag(module);
|
||||
torch::jit::initJITBindings(module);
|
||||
torch::monitor::initMonitorBindings(module);
|
||||
torch::impl::dispatch::initDispatchBindings(module);
|
||||
|
8
torch/csrc/autograd/python_enum_tag.h
Normal file
8
torch/csrc/autograd/python_enum_tag.h
Normal file
@ -0,0 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
void initEnumTag(PyObject* module);
|
||||
}} // namespace torch::autograd
|
@ -1363,7 +1363,7 @@ void initJITBindings(PyObject* module) {
|
||||
return _get_operation_for_overload_or_packet(
|
||||
{op}, symbol, args, kwargs, true);
|
||||
});
|
||||
return func;
|
||||
return py::make_tuple(func, py::cast(op->getTags().vec()));
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Found no matching operator overload");
|
||||
|
@ -160,6 +160,16 @@ struct TORCH_API Operator {
|
||||
});
|
||||
}
|
||||
|
||||
c10::ArrayRef<at::Tag> getTags() const {
|
||||
return op_.fold<c10::ArrayRef<at::Tag>>(
|
||||
[](const C10Operator& op) { return op.handle_.getTags(); },
|
||||
[](const JitOnlyOperator& op) {
|
||||
// Returns empty list of tags for JitOnlyOperators since it
|
||||
// doesn't save c10::OperatorHandle
|
||||
return c10::ArrayRef<at::Tag>();
|
||||
});
|
||||
}
|
||||
|
||||
bool isC10Op() const {
|
||||
return op_.is_left();
|
||||
}
|
||||
|
@ -65,6 +65,7 @@
|
||||
|
||||
// Just for inferFunctionSchemaFromFunctor
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <ATen/core/enum_tag.h>
|
||||
|
||||
namespace torch {
|
||||
|
||||
@ -594,12 +595,12 @@ class TORCH_API Library final {
|
||||
/// m.def("add(Tensor self, Tensor other) -> Tensor");
|
||||
/// }
|
||||
/// ```
|
||||
template <typename Schema>
|
||||
Library& def(Schema&& raw_schema) & {
|
||||
c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
|
||||
return _def(std::move(s));
|
||||
}
|
||||
|
||||
template <typename Schema>
|
||||
Library& def(Schema&& raw_schema, const std::vector<at::Tag>& tags = {}) & {
|
||||
c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
|
||||
return _def(std::move(s), nullptr, tags);
|
||||
}
|
||||
/// Define an operator for a schema and then register an implementation for
|
||||
/// it. This is typically what you would use if you aren't planning
|
||||
/// on making use of the dispatcher to structure your operator
|
||||
@ -813,7 +814,8 @@ class TORCH_API Library final {
|
||||
// public because we only implement & qualifier and not && qualifier
|
||||
Library& _def(
|
||||
c10::FunctionSchema&& schema,
|
||||
c10::OperatorName* out_name = nullptr) &;
|
||||
c10::OperatorName* out_name = nullptr,
|
||||
const std::vector<at::Tag>& tags = {}) &;
|
||||
Library& _def(
|
||||
c10::either<c10::OperatorName, c10::FunctionSchema>&&,
|
||||
CppFunction&& f) &;
|
||||
|
@ -11965,6 +11965,7 @@ op_db: List[OpInfo] = [
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
|
||||
# RuntimeError: Sparse CSR tensors do not have strides.
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
|
||||
# RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
|
||||
# RuntimeError: Sparse CSR tensors do not have strides
|
||||
@ -14123,8 +14124,10 @@ op_db: List[OpInfo] = [
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cpu'),
|
||||
DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
|
||||
'test_nnc_correctness', dtypes=(torch.bfloat16,)),
|
||||
DecorateInfo(unittest.skip("Works on some conifgs"), 'TestCudaFuserOpInfo',
|
||||
'test_nvfuser_correctness', dtypes=(torch.bfloat16,)),
|
||||
# RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet.
|
||||
# Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data()
|
||||
# to actually allocate memory
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
|
||||
),
|
||||
sample_inputs_func=sample_inputs_max_pool),
|
||||
OpInfo('nn.functional.max_pool2d',
|
||||
@ -17626,6 +17629,7 @@ op_db: List[OpInfo] = [
|
||||
# Allowed exception: sparse tensors don't have strides
|
||||
DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'),
|
||||
DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'),
|
||||
DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'),
|
||||
# TODO: implement csr.to_sparse(sample_dim) where sampled_dim is 1.
|
||||
DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"),
|
||||
'TestSparseCSR', 'test_sparse_csr_consistency'),
|
||||
|
@ -152,6 +152,7 @@ class LineLoader(YamlLoader):
|
||||
|
||||
|
||||
_GLOBAL_PARSE_NATIVE_YAML_CACHE = {}
|
||||
_GLOBAL_PARSE_TAGS_YAML_CACHE = {}
|
||||
|
||||
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
|
||||
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
|
||||
@ -220,11 +221,13 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def parse_tags_yaml(path: str) -> Set[str]:
|
||||
# TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object)
|
||||
with open(path, "r") as f:
|
||||
es = yaml.load(f, Loader=LineLoader)
|
||||
valid_tags = parse_tags_yaml_struct(es, path=path)
|
||||
return valid_tags
|
||||
global _GLOBAL_PARSE_TAGS_YAML_CACHE
|
||||
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
|
||||
with open(path, "r") as f:
|
||||
es = yaml.load(f, Loader=LineLoader)
|
||||
_GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)
|
||||
|
||||
return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
|
||||
|
||||
|
||||
def parse_native_yaml(
|
||||
@ -234,7 +237,6 @@ def parse_native_yaml(
|
||||
*,
|
||||
skip_native_fns_gen: bool = False,
|
||||
) -> ParsedYaml:
|
||||
# TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object)
|
||||
global _GLOBAL_PARSE_NATIVE_YAML_CACHE
|
||||
if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
|
||||
valid_tags = parse_tags_yaml(tags_yaml_path)
|
||||
@ -500,7 +502,8 @@ class RegisterSchema:
|
||||
def __call__(self, f: NativeFunction) -> Optional[str]:
|
||||
if not self.selector.is_native_function_selected(f):
|
||||
return None
|
||||
return f"m.def({cpp_string(str(f.func))});\n"
|
||||
tags = "{" + ", ".join([f"at::Tag::{tag}" for tag in f.tags]) + "}"
|
||||
return f"m.def({cpp_string(str(f.func))}, {tags});\n"
|
||||
|
||||
|
||||
# Generates Operators.h and Operators.cpp.
|
||||
@ -1711,6 +1714,7 @@ def gen_per_operator_headers(
|
||||
def gen_headers(
|
||||
*,
|
||||
native_functions: Sequence[NativeFunction],
|
||||
valid_tags: Set[str],
|
||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
||||
structured_native_functions: Sequence[NativeFunctionsGroup],
|
||||
static_dispatch_idx: List[BackendIndex],
|
||||
@ -1838,6 +1842,11 @@ def gen_headers(
|
||||
|
||||
core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
|
||||
|
||||
def gen_tags_enum() -> Dict[str, str]:
|
||||
return {"enum_of_valid_tags": (",\n".join([f"{tag}" for tag in valid_tags]))}
|
||||
|
||||
core_fm.write("enum_tag.h", gen_tags_enum)
|
||||
|
||||
|
||||
def gen_source_files(
|
||||
*,
|
||||
@ -2441,6 +2450,7 @@ def main() -> None:
|
||||
del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
|
||||
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
|
||||
valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
|
||||
native_functions, backend_indices = (
|
||||
parsed_yaml.native_functions,
|
||||
parsed_yaml.backend_indices,
|
||||
@ -2546,6 +2556,7 @@ def main() -> None:
|
||||
if "headers" in options.generate:
|
||||
gen_headers(
|
||||
native_functions=native_functions,
|
||||
valid_tags=valid_tags,
|
||||
grouped_native_functions=grouped_native_functions,
|
||||
structured_native_functions=structured_native_functions,
|
||||
static_dispatch_idx=static_dispatch_idx,
|
||||
|
Reference in New Issue
Block a user