dispatch API for checking computed table, use it in prim decomps (#82358)

Fixes https://github.com/pytorch/pytorch/issues/82331

Expose a `torch._C._dispatch_has_computed_kernel_for_dispatch_key` to check if an operator has a kernel registered to the given dispatch key in the **computed table**.

Use it in the prim registration logic, making it more accurate and robust (so that it e.g. picks up `CompositeExplicitAutograd` kernels.

It looks like before this change we'd register 134 prim ops to the meta key, and after we only register 62. So that's 72 ops that now use an existing C++ decomp to get meta working, instead of going directly through the prim decomp.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82358
Approved by: https://github.com/ezyang
This commit is contained in:
Brian Hirsh
2022-08-10 10:44:06 -07:00
committed by PyTorch MergeBot
parent 8a6b076196
commit 1a51efd8bb
8 changed files with 64 additions and 17 deletions

View File

@ -333,6 +333,9 @@ public:
return operatorDef_->op.hasKernelForDispatchKey(k);
}
bool hasComputedKernelForDispatchKey(DispatchKey k) const {
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
}
std::string dumpComputedTable() const {
return operatorDef_->op.dumpComputedTable();

View File

@ -211,6 +211,13 @@ const KernelFunction& OperatorEntry::kernelForDispatchKey(DispatchKey k) const {
return jt->kernel;
}
bool OperatorEntry::hasComputedKernelForDispatchKey(DispatchKey k) const {
TORCH_CHECK(!isAliasDispatchKey(k), "Alias keys do not have runtime kernel registrations.");
const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
TORCH_INTERNAL_ASSERT(dispatch_ix >= 0 && dispatch_ix < c10::num_runtime_entries, toString(k), dispatch_ix);
return dispatchTable_[dispatch_ix].isValid();
}
const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
auto kern_it = kernels_.find(dispatch_key);
if (kern_it != kernels_.end()) {

View File

@ -210,6 +210,8 @@ public:
// hasKernelForDispatchKey. To get the AnnotatedKernel, see
// getKernelForDispatchKey (private)
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
// Returns true if the "computed table" has an entry for a particular key.
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
// Returns all the operator tags added at the time of registration
const std::vector<at::Tag>& getTags() const;

View File

@ -1108,24 +1108,26 @@ Tensor math_addr(const Tensor& self,
const Scalar& beta, const Scalar& alpha) {
// when beta==0, values in self should be ignored,
// nans and infs in self should not propagate.
Tensor out;
if (beta.toComplexDouble() == 0.0) {
if (alpha.toComplexDouble() == 1.0) {
return at::outer(vec1, vec2);
out = at::outer(vec1, vec2);
} else {
out = alpha * at::outer(vec1, vec2);
}
return alpha * at::outer(vec1, vec2);
}
if (beta.toComplexDouble() == 1.0) {
} else if (beta.toComplexDouble() == 1.0) {
if (alpha.toComplexDouble() == 1.0) {
return self + at::outer(vec1, vec2);
out = self + at::outer(vec1, vec2);
} else {
out = self + alpha * at::outer(vec1, vec2);
}
return self + alpha * at::outer(vec1, vec2);
} else if (alpha.toComplexDouble() == 1.0) {
out = beta * self + at::outer(vec1, vec2);
} else {
out = beta * self + alpha * at::outer(vec1, vec2);
}
if (alpha.toComplexDouble() == 1.0) {
return beta * self + at::outer(vec1, vec2);
}
return beta * self + alpha * at::outer(vec1, vec2);
auto result_type = c10::promoteTypes(c10::promoteTypes(self.scalar_type(), vec1.scalar_type()), vec2.scalar_type());
return out.to(c10::TensorOptions().dtype(result_type));
}
Tensor& math_addr_out(const Tensor& self,

View File

@ -206,6 +206,16 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
const int normalized_ndim = normalized_shape.size();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const int axis = input_ndim - normalized_ndim;
// Properly handle zero-size inputs: the view(1, M, -1) call below breaks on this.
if (input.numel() == 0) {
auto result_type = c10::promoteTypes(input.scalar_type(), kFloat);
return std::make_tuple(
at::empty_like(input),
at::empty_like(input, c10::TensorOptions().dtype(result_type)),
at::empty_like(input, c10::TensorOptions().dtype(result_type))
);
}
at::Tensor input_reshaped = input.view({1, M, -1});
// Unlike Batch Normalization, which applies scalar scale and bias for each
// entire channel/plane with the affine option, Layer Normalization applies

View File

@ -965,6 +965,7 @@ class Generator(object):
# Defined in torch/csrc/utils/python_dispatch.cpp
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...

View File

@ -110,11 +110,15 @@ def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False
# which don't have corresponding dispatcher entries, we need
# to filter those out
and torch._C._dispatch_has_kernel(name)
# Don't register a meta kernel to any operator that has
# a CompositeImplicitAutograd kernel in core.
# Otherwise we won't be able to run autograd for that operator with the meta backend.
and "CompositeImplicitAutograd" not in torch._C._dispatch_dump(name)
and not torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta")
# Don't register a python meta kernel to any operator that has
# should already work with meta tensors today.
# We can check that by seeing if the "computed table" for the operator
# has a registration to Meta;
# either through a direct registration, or an indirect one through
# an alias dispatch key (e.g. CompositeImplicitAutograd)
and not torch._C._dispatch_has_computed_kernel_for_dispatch_key(
name, "Meta"
)
):
if any(
a.alias_info is not None and not a.alias_info.is_write

View File

@ -292,6 +292,8 @@ void initDispatchBindings(PyObject* module) {
});
m.def(
// Returns whether or not a direct kernel registration exists
// for this <op_name, dispatch_key> pair.
"_dispatch_has_kernel_for_dispatch_key",
[](const char* name, const char* dispatch) -> bool {
auto op =
@ -300,6 +302,22 @@ void initDispatchBindings(PyObject* module) {
return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
});
m.def(
// Returns whether or not there is an entry in the runtime computed
// dispatch table, for this <op_name, dispatch_key> pair. For example, if
// "op" has a `CompositeImplicitAutograd` kernel, Then
// _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
// true for all backends that are part of the alias set for
// CompositeImplicitAutograd.
"_dispatch_has_computed_kernel_for_dispatch_key",
[](const char* name, const char* dispatch) -> bool {
auto op =
c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
TORCH_CHECK(op, "operator ", name, " does not exist");
return op->hasComputedKernelForDispatchKey(
c10::parseDispatchKey(dispatch));
});
m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();