dispatch API for checking computed table, use it in prim decomps (#82358)

Fixes https://github.com/pytorch/pytorch/issues/82331

Expose a `torch._C._dispatch_has_computed_kernel_for_dispatch_key` binding to check whether an operator has a kernel registered for the given dispatch key in the **computed table**.
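
A quick usage sketch of the new binding (the operator name and dispatch key below are illustrative, not part of this change):

```python
import torch

# True if the computed dispatch table has a Meta entry for the op, whether that
# comes from a direct Meta registration or from an alias key such as
# CompositeImplicitAutograd / CompositeExplicitAutograd.
print(torch._C._dispatch_has_computed_kernel_for_dispatch_key("aten::sin", "Meta"))

# For comparison, the pre-existing binding only reports direct registrations
# for that exact <op, key> pair.
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::sin", "Meta"))
```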

Use it in the prim registration logic, making the check more accurate and robust (so that it picks up e.g. `CompositeExplicitAutograd` kernels).

Before this change we registered 134 prim ops to the Meta key; after it we register only 62. That's 72 ops that now use an existing C++ decomp to get meta working, instead of going directly through the prim decomp.
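
As a rough sketch of how the new check gets used (the helper name here is made up for illustration; the real logic lives in `register_decomposition` below):

```python
import torch

def needs_python_meta_kernel(qualified_name: str) -> bool:
    # Hypothetical helper: only register a Python meta decomp when the op exists
    # in the dispatcher but its computed table has no Meta entry, i.e. neither a
    # direct Meta kernel nor one derived from an alias key like
    # CompositeImplicitAutograd / CompositeExplicitAutograd.
    return torch._C._dispatch_has_kernel(
        qualified_name
    ) and not torch._C._dispatch_has_computed_kernel_for_dispatch_key(
        qualified_name, "Meta"
    )
```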

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82358
Approved by: https://github.com/ezyang
Authored by Brian Hirsh on 2022-08-10 10:44:06 -07:00
Committed by PyTorch MergeBot
parent 8a6b076196
commit 1a51efd8bb
8 changed files with 64 additions and 17 deletions

@@ -333,6 +333,9 @@ public:
     return operatorDef_->op.hasKernelForDispatchKey(k);
   }
+  bool hasComputedKernelForDispatchKey(DispatchKey k) const {
+    return operatorDef_->op.hasComputedKernelForDispatchKey(k);
+  }
   std::string dumpComputedTable() const {
     return operatorDef_->op.dumpComputedTable();

@@ -211,6 +211,13 @@ const KernelFunction& OperatorEntry::kernelForDispatchKey(DispatchKey k) const {
   return jt->kernel;
 }
+bool OperatorEntry::hasComputedKernelForDispatchKey(DispatchKey k) const {
+  TORCH_CHECK(!isAliasDispatchKey(k), "Alias keys do not have runtime kernel registrations.");
+  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
+  TORCH_INTERNAL_ASSERT(dispatch_ix >= 0 && dispatch_ix < c10::num_runtime_entries, toString(k), dispatch_ix);
+  return dispatchTable_[dispatch_ix].isValid();
+}
 const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
   auto kern_it = kernels_.find(dispatch_key);
   if (kern_it != kernels_.end()) {

@@ -210,6 +210,8 @@ public:
   // hasKernelForDispatchKey. To get the AnnotatedKernel, see
   // getKernelForDispatchKey (private)
   const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
+  // Returns true if the "computed table" has an entry for a particular key.
+  bool hasComputedKernelForDispatchKey(DispatchKey k) const;
   // Returns all the operator tags added at the time of registration
   const std::vector<at::Tag>& getTags() const;

@@ -1108,24 +1108,26 @@ Tensor math_addr(const Tensor& self,
     const Scalar& beta, const Scalar& alpha) {
   // when beta==0, values in self should be ignored,
   // nans and infs in self should not propagate.
+  Tensor out;
   if (beta.toComplexDouble() == 0.0) {
     if (alpha.toComplexDouble() == 1.0) {
-      return at::outer(vec1, vec2);
+      out = at::outer(vec1, vec2);
+    } else {
+      out = alpha * at::outer(vec1, vec2);
     }
-    return alpha * at::outer(vec1, vec2);
-  }
-  if (beta.toComplexDouble() == 1.0) {
+  } else if (beta.toComplexDouble() == 1.0) {
     if (alpha.toComplexDouble() == 1.0) {
-      return self + at::outer(vec1, vec2);
+      out = self + at::outer(vec1, vec2);
+    } else {
+      out = self + alpha * at::outer(vec1, vec2);
     }
-    return self + alpha * at::outer(vec1, vec2);
+  } else if (alpha.toComplexDouble() == 1.0) {
+    out = beta * self + at::outer(vec1, vec2);
+  } else {
+    out = beta * self + alpha * at::outer(vec1, vec2);
   }
-  if (alpha.toComplexDouble() == 1.0) {
-    return beta * self + at::outer(vec1, vec2);
-  }
-  return beta * self + alpha * at::outer(vec1, vec2);
+  auto result_type = c10::promoteTypes(c10::promoteTypes(self.scalar_type(), vec1.scalar_type()), vec2.scalar_type());
+  return out.to(c10::TensorOptions().dtype(result_type));
 }
 Tensor& math_addr_out(const Tensor& self,

@@ -206,6 +206,16 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
   const int normalized_ndim = normalized_shape.size();
   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
   const int axis = input_ndim - normalized_ndim;
+  // Properly handle zero-size inputs: the view(1, M, -1) call below breaks on this.
+  if (input.numel() == 0) {
+    auto result_type = c10::promoteTypes(input.scalar_type(), kFloat);
+    return std::make_tuple(
+        at::empty_like(input),
+        at::empty_like(input, c10::TensorOptions().dtype(result_type)),
+        at::empty_like(input, c10::TensorOptions().dtype(result_type))
+    );
+  }
   at::Tensor input_reshaped = input.view({1, M, -1});
   // Unlike Batch Normalization, which applies scalar scale and bias for each
   // entire channel/plane with the affine option, Layer Normalization applies

@@ -965,6 +965,7 @@ class Generator(object):
 # Defined in torch/csrc/utils/python_dispatch.cpp
 def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
 def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
+def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
 def _dispatch_has_kernel(name: str) -> _bool: ...
 def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
 def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...

@@ -110,11 +110,15 @@ def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False
             # which don't have corresponding dispatcher entries, we need
             # to filter those out
             and torch._C._dispatch_has_kernel(name)
-            # Don't register a meta kernel to any operator that has
-            # a CompositeImplicitAutograd kernel in core.
-            # Otherwise we won't be able to run autograd for that operator with the meta backend.
-            and "CompositeImplicitAutograd" not in torch._C._dispatch_dump(name)
-            and not torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta")
+            # Don't register a python meta kernel to any operator that
+            # should already work with meta tensors today.
+            # We can check that by seeing if the "computed table" for the operator
+            # has a registration to Meta;
+            # either through a direct registration, or an indirect one through
+            # an alias dispatch key (e.g. CompositeImplicitAutograd)
+            and not torch._C._dispatch_has_computed_kernel_for_dispatch_key(
+                name, "Meta"
+            )
         ):
             if any(
                 a.alias_info is not None and not a.alias_info.is_write

@@ -292,6 +292,8 @@ void initDispatchBindings(PyObject* module) {
   });
   m.def(
+      // Returns whether or not a direct kernel registration exists
+      // for this <op_name, dispatch_key> pair.
       "_dispatch_has_kernel_for_dispatch_key",
       [](const char* name, const char* dispatch) -> bool {
         auto op =
@@ -300,6 +302,22 @@ void initDispatchBindings(PyObject* module) {
         return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
       });
+  m.def(
+      // Returns whether or not there is an entry in the runtime computed
+      // dispatch table, for this <op_name, dispatch_key> pair. For example, if
+      // "op" has a `CompositeImplicitAutograd` kernel, then
+      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
+      // true for all backends that are part of the alias set for
+      // CompositeImplicitAutograd.
+      "_dispatch_has_computed_kernel_for_dispatch_key",
+      [](const char* name, const char* dispatch) -> bool {
+        auto op =
+            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
+        TORCH_CHECK(op, "operator ", name, " does not exist");
+        return op->hasComputedKernelForDispatchKey(
+            c10::parseDispatchKey(dispatch));
+      });
   m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
     auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();