mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[RFC][Codegen] Add custom namespace support (#78015)
Summary: Adding a feature to allow the user to specify namespaces for operators and kernels. # Feature There's a feature request to allow the DSL to: 1. take in an operator namespace other than `aten`. 2. take in a kernel that is in a different namespace than `at::native`. For both features, we only allow the user to have a single layer of namespace for the sake of simplicity. If the user specifies `custom::function` as a kernel, the codegen will depend on `custom::native::function`, where `native` is hardcoded. # Proposal For feature 1, add a `namespace` attribute to the data class `NativeFunction`. The namespace will be extracted by matching the pattern "::" on the `func` variable. For `NativeFunctionsGroup` there's an assumption that all variants (function, inplace, out) will have the same namespace. By default (if not specified) the namespace will be "aten". For feature 2, add a `namespace` attribute to the `BackendMetadata` class, similarly matching the pattern "::" on the kernel field. Remove the `cpp_namespace` field from the `register_dispatch_key` data class. By default (if not specified) the namespace for a kernel would be "at::native". Test Plan: Example yaml entries: ``` - func: custom::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!) structured: True structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator python_module: nn dispatch: CPU: custom::gelu_out_cpu CUDA: custom::gelu_out_cuda MPS: custom::gelu_out_mps - func: custom::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!) 
structured_delegate: gelu.out device_check: NoCheck # TensorIterator python_module: nn dispatch: NestedTensorCPU, NestedTensorCUDA: custom::NestedTensor_gelu_ - func: custom::gelu(Tensor self, *, str approximate='none') -> Tensor structured_delegate: gelu.out device_check: NoCheck # TensorIterator python_module: nn dispatch: MkldnnCPU: custom::mkldnn_gelu QuantizedCPU: custom::gelu_quantized_cpu NestedTensorCPU, NestedTensorCUDA: custom::NestedTensor_gelu ``` see generated code: `RegisterCPU.cpp`: ``` TORCH_LIBRARY_IMPL(aten, CPU, m) { ... } TORCH_LIBRARY_IMPL(custom, CPU, m) { m.impl("gelu", TORCH_FN(wrapper_gelu)); m.impl("gelu.out", TORCH_FN(wrapper_gelu_out_out)); m.impl("gelu_", TORCH_FN(wrapper_gelu_)); }; ``` ``` struct structured_gelu_out_cpu_inplace final : public custom::native::structured_gelu_out_cpu { structured_gelu_out_cpu_inplace(Tensor& self) : outputs_{std::ref(self)} {} void set_output_strided( int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names ) override { const auto& out = outputs_[output_idx].get(); check_inplace(out, sizes, options); auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options); if (C10_UNLIKELY(maybe_proxy.has_value())) { proxy_outputs_[output_idx] = c10::ExclusivelyOwned<Tensor>(std::move(maybe_proxy).value()); } if (!names.empty()) { namedinference::propagate_names(outputs_[output_idx], names); } // super must happen after, so that downstream can use maybe_get_output // to retrieve the output custom::native::structured_gelu_out_cpu::set_output_raw_strided(output_idx, sizes, strides, options, names); } void set_output_raw_strided( int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names ) override { const auto& out = outputs_[output_idx].get(); check_inplace(out, sizes, options); if (!names.empty()) { namedinference::propagate_names(outputs_[output_idx], names); } // super must happen after, so that downstream can use 
maybe_get_output // to retrieve the output custom::native::structured_gelu_out_cpu::set_output_raw_strided(output_idx, sizes, strides, options, names); } const Tensor& maybe_get_output(int64_t output_idx) override { return proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get(); } std::array<std::reference_wrapper<Tensor>, 1> outputs_; std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, 1> proxy_outputs_; }; ``` `RegisterSchema.cpp` ``` TORCH_LIBRARY(aten, m) { ... } TORCH_LIBRARY(custom, m) { m.def("gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)"); m.def("gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)"); m.def("gelu(Tensor self, *, str approximate='none') -> Tensor"); }; ``` Differential Revision: D36558459 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78015 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
d28e9e145b
commit
24050a5801
@@ -123,7 +123,9 @@ Only the following keys are supported: {", ".join(valid_keys)}'
                 # See Note [External Backends Follow Dispatcher API]
                 kernel_name = dispatcher.name(native_functions_map[op_name].func)
                 # TODO: allow structured external backends later.
-                m = BackendMetadata(kernel=kernel_name, structured=False)
+                m = BackendMetadata(
+                    kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
+                )
                 metadata[op_name] = m
         return BackendIndex(
             dispatch_key=dispatch_key,
@@ -373,7 +375,6 @@ def gen_dispatcher_registrations(
     fm: FileManager,
     output_dir: str,
     class_name: str,
-    cpp_namespace: str,
     backend_indices: Dict[DispatchKey, BackendIndex],
     grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
     backend_dispatch_key: DispatchKey,
@@ -403,7 +404,6 @@ def gen_dispatcher_registrations(
                     Target.REGISTRATION,
                     selector,
                     rocm=False,
-                    cpp_namespace=cpp_namespace,
                     class_method_name=f"{class_name}",
                     skip_dispatcher_op_registration=False,
                 ),
@@ -462,7 +462,6 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
                     Target.ANONYMOUS_DEFINITION,
                     selector,
                     rocm=False,
-                    cpp_namespace=cpp_namespace,
                     class_method_name=f"{class_name}",
                     skip_dispatcher_op_registration=False,
                 ),
@@ -548,7 +547,6 @@ def run(
         fm,
         output_dir,
         class_name,
-        cpp_namespace,
         backend_indices,
         grouped_native_functions,
         backend_key,
|
Reference in New Issue
Block a user