mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
dispatch API for checking computed table, use it in prim decomps (#82358)
Fixes https://github.com/pytorch/pytorch/issues/82331 Expose a `torch._C._dispatch_has_computed_kernel_for_dispatch_key` to check if an operator has a kernel registered to the given dispatch key in the **computed table**. Use it in the prim registration logic, making it more accurate and robust (so that it e.g. picks up `CompositeExplicitAutograd` kernels. It looks like before this change we'd register 134 prim ops to the meta key, and after we only register 62. So that's 72 ops that now use an existing C++ decomp to get meta working, instead of going directly through the prim decomp. Pull Request resolved: https://github.com/pytorch/pytorch/pull/82358 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
8a6b076196
commit
1a51efd8bb
@ -333,6 +333,9 @@ public:
|
||||
return operatorDef_->op.hasKernelForDispatchKey(k);
|
||||
}
|
||||
|
||||
bool hasComputedKernelForDispatchKey(DispatchKey k) const {
|
||||
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
|
||||
}
|
||||
|
||||
std::string dumpComputedTable() const {
|
||||
return operatorDef_->op.dumpComputedTable();
|
||||
|
@ -211,6 +211,13 @@ const KernelFunction& OperatorEntry::kernelForDispatchKey(DispatchKey k) const {
|
||||
return jt->kernel;
|
||||
}
|
||||
|
||||
bool OperatorEntry::hasComputedKernelForDispatchKey(DispatchKey k) const {
|
||||
TORCH_CHECK(!isAliasDispatchKey(k), "Alias keys do not have runtime kernel registrations.");
|
||||
const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
|
||||
TORCH_INTERNAL_ASSERT(dispatch_ix >= 0 && dispatch_ix < c10::num_runtime_entries, toString(k), dispatch_ix);
|
||||
return dispatchTable_[dispatch_ix].isValid();
|
||||
}
|
||||
|
||||
const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
|
||||
auto kern_it = kernels_.find(dispatch_key);
|
||||
if (kern_it != kernels_.end()) {
|
||||
|
@ -210,6 +210,8 @@ public:
|
||||
// hasKernelForDispatchKey. To get the AnnotatedKernel, see
|
||||
// getKernelForDispatchKey (private)
|
||||
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
|
||||
// Returns true if the "computed table" has an entry for a particular key.
|
||||
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
|
||||
// Returns all the operator tags added at the time of registration
|
||||
const std::vector<at::Tag>& getTags() const;
|
||||
|
||||
|
@ -1108,24 +1108,26 @@ Tensor math_addr(const Tensor& self,
|
||||
const Scalar& beta, const Scalar& alpha) {
|
||||
// when beta==0, values in self should be ignored,
|
||||
// nans and infs in self should not propagate.
|
||||
Tensor out;
|
||||
if (beta.toComplexDouble() == 0.0) {
|
||||
if (alpha.toComplexDouble() == 1.0) {
|
||||
return at::outer(vec1, vec2);
|
||||
out = at::outer(vec1, vec2);
|
||||
} else {
|
||||
out = alpha * at::outer(vec1, vec2);
|
||||
}
|
||||
return alpha * at::outer(vec1, vec2);
|
||||
}
|
||||
|
||||
if (beta.toComplexDouble() == 1.0) {
|
||||
} else if (beta.toComplexDouble() == 1.0) {
|
||||
if (alpha.toComplexDouble() == 1.0) {
|
||||
return self + at::outer(vec1, vec2);
|
||||
out = self + at::outer(vec1, vec2);
|
||||
} else {
|
||||
out = self + alpha * at::outer(vec1, vec2);
|
||||
}
|
||||
return self + alpha * at::outer(vec1, vec2);
|
||||
} else if (alpha.toComplexDouble() == 1.0) {
|
||||
out = beta * self + at::outer(vec1, vec2);
|
||||
} else {
|
||||
out = beta * self + alpha * at::outer(vec1, vec2);
|
||||
}
|
||||
|
||||
if (alpha.toComplexDouble() == 1.0) {
|
||||
return beta * self + at::outer(vec1, vec2);
|
||||
}
|
||||
return beta * self + alpha * at::outer(vec1, vec2);
|
||||
auto result_type = c10::promoteTypes(c10::promoteTypes(self.scalar_type(), vec1.scalar_type()), vec2.scalar_type());
|
||||
return out.to(c10::TensorOptions().dtype(result_type));
|
||||
}
|
||||
|
||||
Tensor& math_addr_out(const Tensor& self,
|
||||
|
@ -206,6 +206,16 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
|
||||
const int normalized_ndim = normalized_shape.size();
|
||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||
const int axis = input_ndim - normalized_ndim;
|
||||
|
||||
// Properly handle zero-size inputs: the view(1, M, -1) call below breaks on this.
|
||||
if (input.numel() == 0) {
|
||||
auto result_type = c10::promoteTypes(input.scalar_type(), kFloat);
|
||||
return std::make_tuple(
|
||||
at::empty_like(input),
|
||||
at::empty_like(input, c10::TensorOptions().dtype(result_type)),
|
||||
at::empty_like(input, c10::TensorOptions().dtype(result_type))
|
||||
);
|
||||
}
|
||||
at::Tensor input_reshaped = input.view({1, M, -1});
|
||||
// Unlike Batch Normalization, which applies scalar scale and bias for each
|
||||
// entire channel/plane with the affine option, Layer Normalization applies
|
||||
|
@ -965,6 +965,7 @@ class Generator(object):
|
||||
# Defined in torch/csrc/utils/python_dispatch.cpp
|
||||
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
|
||||
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
|
||||
def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
|
||||
def _dispatch_has_kernel(name: str) -> _bool: ...
|
||||
def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
|
||||
def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...
|
||||
|
@ -110,11 +110,15 @@ def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False
|
||||
# which don't have corresponding dispatcher entries, we need
|
||||
# to filter those out
|
||||
and torch._C._dispatch_has_kernel(name)
|
||||
# Don't register a meta kernel to any operator that has
|
||||
# a CompositeImplicitAutograd kernel in core.
|
||||
# Otherwise we won't be able to run autograd for that operator with the meta backend.
|
||||
and "CompositeImplicitAutograd" not in torch._C._dispatch_dump(name)
|
||||
and not torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta")
|
||||
# Don't register a python meta kernel to any operator that has
|
||||
# should already work with meta tensors today.
|
||||
# We can check that by seeing if the "computed table" for the operator
|
||||
# has a registration to Meta;
|
||||
# either through a direct registration, or an indirect one through
|
||||
# an alias dispatch key (e.g. CompositeImplicitAutograd)
|
||||
and not torch._C._dispatch_has_computed_kernel_for_dispatch_key(
|
||||
name, "Meta"
|
||||
)
|
||||
):
|
||||
if any(
|
||||
a.alias_info is not None and not a.alias_info.is_write
|
||||
|
@ -292,6 +292,8 @@ void initDispatchBindings(PyObject* module) {
|
||||
});
|
||||
|
||||
m.def(
|
||||
// Returns whether or not a direct kernel registration exists
|
||||
// for this <op_name, dispatch_key> pair.
|
||||
"_dispatch_has_kernel_for_dispatch_key",
|
||||
[](const char* name, const char* dispatch) -> bool {
|
||||
auto op =
|
||||
@ -300,6 +302,22 @@ void initDispatchBindings(PyObject* module) {
|
||||
return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
|
||||
});
|
||||
|
||||
m.def(
|
||||
// Returns whether or not there is an entry in the runtime computed
|
||||
// dispatch table, for this <op_name, dispatch_key> pair. For example, if
|
||||
// "op" has a `CompositeImplicitAutograd` kernel, Then
|
||||
// _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
|
||||
// true for all backends that are part of the alias set for
|
||||
// CompositeImplicitAutograd.
|
||||
"_dispatch_has_computed_kernel_for_dispatch_key",
|
||||
[](const char* name, const char* dispatch) -> bool {
|
||||
auto op =
|
||||
c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
|
||||
TORCH_CHECK(op, "operator ", name, " does not exist");
|
||||
return op->hasComputedKernelForDispatchKey(
|
||||
c10::parseDispatchKey(dispatch));
|
||||
});
|
||||
|
||||
m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
|
||||
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
|
||||
|
||||
|
Reference in New Issue
Block a user