Autogen Tags enum, and allow specifying tags while defining an op

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79322

Approved by: https://github.com/albanD
Author: anjali411
Date: 2022-06-10 21:48:56 +00:00
Committed by: PyTorch MergeBot
Parent: bdcee8f995
Commit: 38350acf8f
28 changed files with 216 additions and 31 deletions

View File

@ -160,6 +160,7 @@ ATEN_EXPORTED_HEADERS = {
"RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]",
"core/TensorBody.h": ":gen_aten[core/TensorBody.h]",
"core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]",
"core/enum_tag.h": ":gen_aten[core/enum_tag.h]",
}
cxx_library(

View File

@ -73,6 +73,7 @@ generated_cpu_cpp = [
"aten/src/ATen/NativeMetaFunctions.h",
"aten/src/ATen/RegistrationDeclarations.h",
"aten/src/ATen/core/aten_interned_strings.h",
"aten/src/ATen/core/enum_tag.h",
"aten/src/ATen/core/TensorBody.h",
"aten/src/ATen/core/TensorMethods.cpp",
"aten/src/ATen/core/ATenOpList.cpp",

View File

@ -147,7 +147,7 @@ void Dispatcher::deregisterLibrary_(const std::string& ns) {
libraries_.erase(ns);
}
RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug) {
RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
// we need a lock to avoid concurrent writes
std::lock_guard<std::mutex> lock(mutex_);
@ -157,7 +157,7 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin
TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
" Each overload's schema should only be registered with a single call to def().",
" Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug));
op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug), tags);
listeners_->callOnOperatorRegistered(op);
// NB: do not increment the counts until AFTER error checking

View File

@ -14,6 +14,7 @@
#include <type_traits>
#include <ATen/core/grad_mode.h>
#include <ATen/core/enum_tag.h>
namespace c10 {
@ -177,7 +178,7 @@ public:
* If a schema with the same operator name and overload name already exists,
* this function will check that both schemas are exactly identical.
*/
RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug);
RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags = {});
/**
* Register a kernel to the dispatch table for an operator.
@ -338,6 +339,19 @@ public:
return operatorDef_->op.checkInvariants();
}
c10::ArrayRef<at::Tag> getTags() const {
return operatorDef_->op.getTags();
}
bool hasTag(const at::Tag& tag) const {
for(const auto& tag_: getTags()) {
if (tag == tag_) {
return true;
}
}
return false;
}
template<class FuncType>
TypedOperatorHandle<FuncType> typed() const {
// NB: This assert is not 100% sound: you can retrieve a typed() operator
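The two accessors added to OperatorHandle above (getTags() and hasTag()) make the registration-time tags queryable from C++. A minimal sketch of how they might be used follows; the operator looked up here is an assumption for illustration and is not part of this change.

#include <ATen/core/dispatch/Dispatcher.h>

// Illustrative sketch: look up an (assumed) aten operator and check one tag.
bool transpose_is_inplace_view() {
  c10::OperatorHandle op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("aten::transpose_", "");
  // hasTag() scans the tags recorded when the schema was registered.
  return op.hasTag(at::Tag::inplace_view);
}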

View File

@ -19,6 +19,9 @@ namespace {
OperatorEntry::OperatorEntry(OperatorName&& operator_name)
: name_(std::move(operator_name))
, schema_()
#ifndef C10_MOBILE
, tags_()
#endif
, dispatchTable_()
, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
, kernels_()
@ -57,7 +60,7 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
return kernel;
}
void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug) {
void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug, std::vector<at::Tag> tags) {
TORCH_INTERNAL_ASSERT(!schema_.has_value());
for (const auto& kernel : kernels_) {
for (const auto &j : kernel.second) {
@ -69,6 +72,9 @@ void OperatorEntry::registerSchema(FunctionSchema&& schema, std::string&& debug)
// NB: don't register schema until after we've checked everything!
dispatchKeyExtractor_.registerSchema(schema);
schema_ = AnnotatedSchema(std::move(schema), std::move(debug));
#ifndef C10_MOBILE
tags_ = std::move(tags);
#endif
}
void OperatorEntry::deregisterSchema() {
@ -208,6 +214,14 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
return nullptr;
}
const std::vector<at::Tag>& OperatorEntry::getTags() const {
#if defined C10_MOBILE
TORCH_CHECK(false, "tags are not saved for Mobile");
#else
return tags_;
#endif
}
std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
// [Note] DispatchTable computation
// dispatchTable contains entries for runtime dispatch keys.

View File

@ -13,6 +13,7 @@
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <ATen/core/enum_tag.h>
#include <list>
#include <array>
@ -98,7 +99,7 @@ public:
// attempt to register a schema when one is already present or vice
// versa that is an error. (Refcounting for the registrations is
// handled in the OperatorHandle in Dispatcher)
void registerSchema(FunctionSchema&&, std::string&& debug);
void registerSchema(FunctionSchema&&, std::string&& debug, std::vector<at::Tag> tags = {});
void deregisterSchema();
const OperatorName& operator_name() const {
@ -205,12 +206,16 @@ public:
bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
// Returns true if kernel_ has entry for a particular key.
bool hasKernelForDispatchKey(DispatchKey k) const;
// Returns all the operator tags added at the time of registration
const std::vector<at::Tag>& getTags() const;
private:
OperatorName name_;
c10::optional<AnnotatedSchema> schema_;
#ifndef C10_MOBILE
std::vector<at::Tag> tags_;
#endif
std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;

View File

@ -89,7 +89,7 @@ Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, c
// merge everything
#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & {
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector<at::Tag>& tags) & {
TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
DEF_PRELUDE,
"Cannot define an operator inside of a ", toString(kind_), " block. "
@ -128,7 +128,8 @@ Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name
registrars_.emplace_back(
c10::Dispatcher::singleton().registerDef(
std::move(schema),
debugString(file_, line_)
debugString(file_, line_),
tags
)
);
return *this;

View File

@ -12,3 +12,7 @@
desc: |
This tag indicates if an operator's output's shape depends on input Tensor
data.
- tag: generated
desc: |
This tag indicates that the operator doesn't have an explicit entry in
native_functions.yaml, and instead was generated automatically by the codegen.

View File

@ -0,0 +1,10 @@
#pragma once
// ${generated_comment}
namespace at {
// Enum of valid tags obtained from the entries in tags.yaml
enum class Tag {
${enum_of_valid_tags}
};
}
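Once ${enum_of_valid_tags} is substituted by the codegen (gen_tags_enum in torchgen/gen.py), the generated core/enum_tag.h plausibly looks like the sketch below; the member list and its ordering come from tags.yaml at build time, so treat the entries shown here as assumptions.

#pragma once

namespace at {
    // Enum of valid tags obtained from the entries in tags.yaml
    enum class Tag {
        dynamic_output_shape,
        generated,
        inplace_view,
        view_copy
    };
}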

View File

@ -101,6 +101,7 @@ GENERATED_H_CORE = [
"core/TensorBody.h",
"MethodOperators.h",
"core/aten_interned_strings.h",
"core/enum_tag.h",
]
GENERATED_H_CUDA = [
@ -186,6 +187,7 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
"torch/csrc/autograd/generated/python_fft_functions.cpp",
"torch/csrc/autograd/generated/python_linalg_functions.cpp",
"torch/csrc/autograd/generated/python_return_types.cpp",
"torch/csrc/autograd/generated/python_enum_tag.cpp",
"torch/csrc/autograd/generated/python_sparse_functions.cpp",
"torch/csrc/autograd/generated/python_special_functions.cpp",
"torch/csrc/autograd/generated/python_torch_functions_0.cpp",

View File

@ -959,6 +959,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
"torch/csrc/autograd/generated/python_nn_functions.cpp",
"torch/csrc/autograd/generated/python_fft_functions.cpp",
"torch/csrc/autograd/generated/python_linalg_functions.cpp",
"torch/csrc/autograd/generated/python_enum_tag.cpp",
"torch/csrc/autograd/generated/python_return_types.cpp",
"torch/csrc/autograd/generated/python_sparse_functions.cpp",
"torch/csrc/autograd/generated/python_special_functions.cpp",

View File

@ -401,6 +401,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
)
set(GENERATED_H_PYTHON
@ -463,6 +464,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TOOLS_PATH}/autograd/templates/python_sparse_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_special_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_return_types.cpp"
"${TOOLS_PATH}/autograd/templates/python_enum_tag.cpp"
"${TOOLS_PATH}/autograd/templates/variable_factories.h"
"${TOOLS_PATH}/autograd/templates/annotated_fn_args.py.in"
"${TOOLS_PATH}/autograd/deprecated.yaml"

View File

@ -613,6 +613,10 @@ Utilities
vmap
_assert
Operator Tags
------------------------------------
.. autoclass:: Tag
:members:
.. Empty submodules added only for tracking.
.. py:module:: torch.contrib

View File

@ -357,6 +357,7 @@ def get_aten_generated_files(enabled_backends):
"core/TensorBody.h",
"core/TensorMethods.cpp",
"core/aten_interned_strings.h",
"core/enum_tag.h",
] + get_aten_derived_type_srcs(enabled_backends)
# This is tiresome. A better strategy would be to unconditionally

View File

@ -65,6 +65,7 @@ import torch.testing._internal.opinfo_helper as opinfo_helper
from torch.testing._internal import composite_compliance
from torch.utils._pytree import tree_flatten
from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode
# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)
@ -1432,6 +1433,57 @@ class TestMathBits(TestCase):
torch.is_complex,
)
# input strides and size may have been altered due to the result of an inplace op
def test_inplace_view(func, input, rs, input_size, input_strides):
if func is None:
return
# TODO: extend this test to cover ops with multiple outputs and ops like native_batch_norm.out
# which don't necessarily mutate their first input.
if isinstance(rs, torch.Tensor) and rs is input:
unequal_size = rs.size() != input_size
unequal_strides = rs.stride() != input_strides
# resize_ should probably have inplace_view tag. Not adding the tag since it
# breaks some codegen logic
if (unequal_size or unequal_strides):
if isinstance(func, torch._ops.OpOverloadPacket):
func = func.default
# Reference: https://github.com/pytorch/pytorch/issues/78759
if func is not torch.ops.aten.resize_.default:
# TODO: use self.assertIn when we have separate tests for each tag
assert torch.Tag.inplace_view in func.tags
# A mode that when enabled runs correctness checks to ensure
# that operators have expected tags based on their input and
# output tensor properties
class TestTagsMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if isinstance(args[0], torch.Tensor):
old_size = args[0].size()
old_stride = args[0].stride()
rs = func(*args, **kwargs)
test_inplace_view(func, args[0], rs, old_size, old_stride)
else:
rs = func(*args, **kwargs)
return rs
# Test to verify the correctness of tags in `tags.yaml`, which are also accessible from Python through `torch.Tag`
class TestTags(TestCase):
@onlyCPU
@ops(ops_and_refs, dtypes=OpDTypes.any_one)
def test_tags(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=False)
for sample in samples:
# TODO: Test tags for ops that return a list of tensors
input = sample.input
if isinstance(input, torch.Tensor):
old_size = input.size()
old_stride = input.stride()
with push_torch_dispatch_mode(TestTagsMode):
rs = op(input, *sample.args, **sample.kwargs)
# TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
aten_name = op.aten_name if op.aten_name is not None else op.name
opoverloadpacket = getattr(torch.ops.aten, aten_name, None)
test_inplace_view(opoverloadpacket, input, rs, old_size, old_stride)
class TestRefsOpsInfo(TestCase):
@ -1574,8 +1626,7 @@ instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals())
instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
instantiate_device_type_tests(TestFakeTensorNonErroring, globals())
instantiate_device_type_tests(TestTags, globals())
if __name__ == "__main__":
run_tests()

View File

@ -241,6 +241,11 @@ class TestPublicBindings(TestCase):
"vitals_enabled",
"wait",
"Tag",
"inplace_view",
"view_copy",
"generated",
"dynamic_output_shape",
}
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}

View File

@ -25,6 +25,7 @@ python_library(
"templates/python_torch_functions.cpp",
"templates/python_variable_methods.cpp",
"templates/variable_factories.h",
"templates/python_enum_tag.cpp",
],
visibility = ["PUBLIC"],
deps = [

View File

@ -57,7 +57,7 @@ from torchgen.api.python import (
namedtuple_fieldnames,
signature,
)
from torchgen.gen import cpp_string, parse_native_yaml
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
from torchgen.context import with_native_function
from torchgen.model import (
Argument,
@ -326,6 +326,17 @@ def gen(
fm, functions, lambda fn: True, "python_return_types.cpp"
)
valid_tags = parse_tags_yaml(tags_yaml_path)
def gen_tags_enum() -> Dict[str, str]:
return {
"enum_of_valid_tags": (
"".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags])
)
}
fm.write("python_enum_tag.cpp", gen_tags_enum)
def group_filter_overloads(
pairs: Sequence[PythonSignatureNativeFunctionPair],

View File

@ -0,0 +1,15 @@
#include <torch/csrc/autograd/python_enum_tag.h>
#include <pybind11/pybind11.h>
#include <ATen/core/enum_tag.h>
namespace py = pybind11;
namespace torch {
namespace autograd {
void initEnumTag(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::enum_<at::Tag>(m, "Tag")
${enum_of_valid_tags}
.export_values();
m.doc() = "An Enum that contains tags that can be assigned to an operator registered in C++.";
}
}}
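Similarly, once gen_python_functions.py fills in ${enum_of_valid_tags} with one .value(...) entry per tag, the emitted python_enum_tag.cpp plausibly expands to something like the following; the concrete tag names are again assumptions taken from tags.yaml.

#include <torch/csrc/autograd/python_enum_tag.h>
#include <pybind11/pybind11.h>
#include <ATen/core/enum_tag.h>

namespace py = pybind11;
namespace torch {
namespace autograd {
    // Illustrative expansion only; the real file is generated at build time.
    void initEnumTag(PyObject* module) {
        auto m = py::handle(module).cast<py::module>();
        py::enum_<at::Tag>(m, "Tag")
            .value("inplace_view", at::Tag::inplace_view)
            .value("view_copy", at::Tag::view_copy)
            .value("dynamic_output_shape", at::Tag::dynamic_output_shape)
            .value("generated", at::Tag::generated)
            .export_values();
        m.doc() = "An Enum that contains tags that can be assigned to an operator registered in C++.";
    }
}}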

View File

@ -202,7 +202,7 @@ def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ...
def _get_operation_overload(op_name: str, op_overload_name: str) -> Callable: ...
def _get_operation_overload(op_name: str, op_overload_name: str) -> Tuple[Callable, List[Any]]: ...
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
optimization_blocklist: Set[MobileOptimizerType],

View File

@ -28,10 +28,11 @@ def dl_open_guard():
# Each OpOverload object contains a pointer to a specific operator overload and a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
class OpOverload:
def __init__(self, overloadpacket, op, schema):
def __init__(self, overloadpacket, op, schema, tags):
self._op = op
self._schema = schema
self._overloadpacket = overloadpacket
self._tags = tags
self._overloadname = 'default' if schema.overload_name == '' else schema.overload_name
self.__name__ = "{}.{}".format(self._schema.name.split("::")[1], self._overloadname)
self.__module__ = overloadpacket.__module__
@ -65,6 +66,10 @@ class OpOverload:
def op(self):
return self._op
@property
def tags(self):
return self._tags
# TODO: add more methods to expose information about input and output arguments
# The OpOverloadPacket class contains a pointer to a base unresolved operator that doesn't correspond to a specific operator
@ -123,10 +128,10 @@ class OpOverloadPacket:
# This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
use_key = '' if key == 'default' else key
# TODO: disallow access to overloads registered by JIT
op_ = torch._C._get_operation_overload(
op_, tags = torch._C._get_operation_overload(
self._qualified_op_name, use_key)
schema = torch._C._get_schema(self._qualified_op_name, use_key)
overload = OpOverload(self, op_, schema)
overload = OpOverload(self, op_, schema, tags)
# cache the overload object
setattr(self, key, overload)
return overload

View File

@ -41,6 +41,7 @@
#include <torch/csrc/autograd/python_sparse_functions.h>
#include <torch/csrc/autograd/python_special_functions.h>
#include <torch/csrc/autograd/python_return_types.h>
#include <torch/csrc/autograd/python_enum_tag.h>
#include <torch/csrc/autograd/python_legacy_variable.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/multiprocessing/init.h>
@ -828,6 +829,7 @@ PyObject* initModule() {
// the export side of JIT, so this ONNX init needs to appear before the JIT
// init.
torch::onnx::initONNXBindings(module);
torch::autograd::initEnumTag(module);
torch::jit::initJITBindings(module);
torch::monitor::initMonitorBindings(module);
torch::impl::dispatch::initDispatchBindings(module);

View File

@ -0,0 +1,8 @@
#pragma once
#include <torch/csrc/python_headers.h>
namespace torch {
namespace autograd {
void initEnumTag(PyObject* module);
}} // namespace torch::autograd

View File

@ -1363,7 +1363,7 @@ void initJITBindings(PyObject* module) {
return _get_operation_for_overload_or_packet(
{op}, symbol, args, kwargs, true);
});
return func;
return py::make_tuple(func, py::cast(op->getTags().vec()));
}
}
throw std::runtime_error("Found no matching operator overload");

View File

@ -160,6 +160,16 @@ struct TORCH_API Operator {
});
}
c10::ArrayRef<at::Tag> getTags() const {
return op_.fold<c10::ArrayRef<at::Tag>>(
[](const C10Operator& op) { return op.handle_.getTags(); },
[](const JitOnlyOperator& op) {
// Returns an empty list of tags for JitOnlyOperators since they
// don't hold a c10::OperatorHandle
return c10::ArrayRef<at::Tag>();
});
}
bool isC10Op() const {
return op_.is_left();
}

View File

@ -65,6 +65,7 @@
// Just for inferFunctionSchemaFromFunctor
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/enum_tag.h>
namespace torch {
@ -594,12 +595,12 @@ class TORCH_API Library final {
/// m.def("add(Tensor self, Tensor other) -> Tensor");
/// }
/// ```
template <typename Schema>
Library& def(Schema&& raw_schema) & {
c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
return _def(std::move(s));
}
template <typename Schema>
Library& def(Schema&& raw_schema, const std::vector<at::Tag>& tags = {}) & {
c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
return _def(std::move(s), nullptr, tags);
}
/// Define an operator for a schema and then register an implementation for
/// it. This is typically what you would use if you aren't planning
/// on making use of the dispatcher to structure your operator
@ -813,7 +814,8 @@ class TORCH_API Library final {
// public because we only implement & qualifier and not && qualifier
Library& _def(
c10::FunctionSchema&& schema,
c10::OperatorName* out_name = nullptr) &;
c10::OperatorName* out_name = nullptr,
const std::vector<at::Tag>& tags = {}) &;
Library& _def(
c10::either<c10::OperatorName, c10::FunctionSchema>&&,
CppFunction&& f) &;
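With the optional tags parameter on Library::def/_def, tags can now be attached when defining an operator from C++. A minimal sketch follows, assuming a hypothetical myops namespace and schema; the tag chosen is only illustrative.

#include <torch/library.h>

TORCH_LIBRARY(myops, m) {
  // The second argument is the new optional tag list; the tags are stored with
  // the schema and can later be read back via OperatorHandle::getTags()/hasTag().
  m.def("my_nonzero(Tensor self) -> Tensor",
        {at::Tag::dynamic_output_shape});
}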

View File

@ -11965,6 +11965,7 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
# RuntimeError: Sparse CSR tensors do not have strides.
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
# RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
# RuntimeError: Sparse CSR tensors do not have strides
@ -14123,8 +14124,10 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cpu'),
DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
'test_nnc_correctness', dtypes=(torch.bfloat16,)),
DecorateInfo(unittest.skip("Works on some conifgs"), 'TestCudaFuserOpInfo',
'test_nvfuser_correctness', dtypes=(torch.bfloat16,)),
# RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet.
# Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data()
# to actually allocate memory
DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
),
sample_inputs_func=sample_inputs_max_pool),
OpInfo('nn.functional.max_pool2d',
@ -17626,6 +17629,7 @@ op_db: List[OpInfo] = [
# Allowed exception: sparse tensors don't have strides
DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'),
DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'),
DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'),
# TODO: implement csr.to_sparse(sample_dim) where sampled_dim is 1.
DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"),
'TestSparseCSR', 'test_sparse_csr_consistency'),

View File

@ -152,6 +152,7 @@ class LineLoader(YamlLoader):
_GLOBAL_PARSE_NATIVE_YAML_CACHE = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE = {}
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
@ -220,11 +221,13 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> Set[str]:
# TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object)
with open(path, "r") as f:
es = yaml.load(f, Loader=LineLoader)
valid_tags = parse_tags_yaml_struct(es, path=path)
return valid_tags
global _GLOBAL_PARSE_TAGS_YAML_CACHE
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
with open(path, "r") as f:
es = yaml.load(f, Loader=LineLoader)
_GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)
return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
def parse_native_yaml(
@ -234,7 +237,6 @@ def parse_native_yaml(
*,
skip_native_fns_gen: bool = False,
) -> ParsedYaml:
# TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object)
global _GLOBAL_PARSE_NATIVE_YAML_CACHE
if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
valid_tags = parse_tags_yaml(tags_yaml_path)
@ -500,7 +502,8 @@ class RegisterSchema:
def __call__(self, f: NativeFunction) -> Optional[str]:
if not self.selector.is_native_function_selected(f):
return None
return f"m.def({cpp_string(str(f.func))});\n"
tags = "{" + ", ".join([f"at::Tag::{tag}" for tag in f.tags]) + "}"
return f"m.def({cpp_string(str(f.func))}, {tags});\n"
# Generates Operators.h and Operators.cpp.
@ -1711,6 +1714,7 @@ def gen_per_operator_headers(
def gen_headers(
*,
native_functions: Sequence[NativeFunction],
valid_tags: Set[str],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
structured_native_functions: Sequence[NativeFunctionsGroup],
static_dispatch_idx: List[BackendIndex],
@ -1838,6 +1842,11 @@ def gen_headers(
core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
def gen_tags_enum() -> Dict[str, str]:
return {"enum_of_valid_tags": (",\n".join([f"{tag}" for tag in valid_tags]))}
core_fm.write("enum_tag.h", gen_tags_enum)
def gen_source_files(
*,
@ -2441,6 +2450,7 @@ def main() -> None:
del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
native_functions, backend_indices = (
parsed_yaml.native_functions,
parsed_yaml.backend_indices,
@ -2546,6 +2556,7 @@ def main() -> None:
if "headers" in options.generate:
gen_headers(
native_functions=native_functions,
valid_tags=valid_tags,
grouped_native_functions=grouped_native_functions,
structured_native_functions=structured_native_functions,
static_dispatch_idx=static_dispatch_idx,