mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
dispatch API for checking computed table, use it in prim decomps (#82358)
Fixes https://github.com/pytorch/pytorch/issues/82331 Expose a `torch._C._dispatch_has_computed_kernel_for_dispatch_key` to check if an operator has a kernel registered to the given dispatch key in the **computed table**. Use it in the prim registration logic, making it more accurate and robust (so that it e.g. picks up `CompositeExplicitAutograd` kernels. It looks like before this change we'd register 134 prim ops to the meta key, and after we only register 62. So that's 72 ops that now use an existing C++ decomp to get meta working, instead of going directly through the prim decomp. Pull Request resolved: https://github.com/pytorch/pytorch/pull/82358 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
8a6b076196
commit
1a51efd8bb
@ -333,6 +333,9 @@ public:
|
|||||||
return operatorDef_->op.hasKernelForDispatchKey(k);
|
return operatorDef_->op.hasKernelForDispatchKey(k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasComputedKernelForDispatchKey(DispatchKey k) const {
|
||||||
|
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
|
||||||
|
}
|
||||||
|
|
||||||
std::string dumpComputedTable() const {
|
std::string dumpComputedTable() const {
|
||||||
return operatorDef_->op.dumpComputedTable();
|
return operatorDef_->op.dumpComputedTable();
|
||||||
|
|||||||
@ -211,6 +211,13 @@ const KernelFunction& OperatorEntry::kernelForDispatchKey(DispatchKey k) const {
|
|||||||
return jt->kernel;
|
return jt->kernel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool OperatorEntry::hasComputedKernelForDispatchKey(DispatchKey k) const {
|
||||||
|
TORCH_CHECK(!isAliasDispatchKey(k), "Alias keys do not have runtime kernel registrations.");
|
||||||
|
const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
|
||||||
|
TORCH_INTERNAL_ASSERT(dispatch_ix >= 0 && dispatch_ix < c10::num_runtime_entries, toString(k), dispatch_ix);
|
||||||
|
return dispatchTable_[dispatch_ix].isValid();
|
||||||
|
}
|
||||||
|
|
||||||
const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
|
const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispatch_key) const{
|
||||||
auto kern_it = kernels_.find(dispatch_key);
|
auto kern_it = kernels_.find(dispatch_key);
|
||||||
if (kern_it != kernels_.end()) {
|
if (kern_it != kernels_.end()) {
|
||||||
|
|||||||
@ -210,6 +210,8 @@ public:
|
|||||||
// hasKernelForDispatchKey. To get the AnnotatedKernel, see
|
// hasKernelForDispatchKey. To get the AnnotatedKernel, see
|
||||||
// getKernelForDispatchKey (private)
|
// getKernelForDispatchKey (private)
|
||||||
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
|
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
|
||||||
|
// Returns true if the "computed table" has an entry for a particular key.
|
||||||
|
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
|
||||||
// Returns all the operator tags added at the time of registration
|
// Returns all the operator tags added at the time of registration
|
||||||
const std::vector<at::Tag>& getTags() const;
|
const std::vector<at::Tag>& getTags() const;
|
||||||
|
|
||||||
|
|||||||
@ -1108,24 +1108,26 @@ Tensor math_addr(const Tensor& self,
|
|||||||
const Scalar& beta, const Scalar& alpha) {
|
const Scalar& beta, const Scalar& alpha) {
|
||||||
// when beta==0, values in self should be ignored,
|
// when beta==0, values in self should be ignored,
|
||||||
// nans and infs in self should not propagate.
|
// nans and infs in self should not propagate.
|
||||||
|
Tensor out;
|
||||||
if (beta.toComplexDouble() == 0.0) {
|
if (beta.toComplexDouble() == 0.0) {
|
||||||
if (alpha.toComplexDouble() == 1.0) {
|
if (alpha.toComplexDouble() == 1.0) {
|
||||||
return at::outer(vec1, vec2);
|
out = at::outer(vec1, vec2);
|
||||||
|
} else {
|
||||||
|
out = alpha * at::outer(vec1, vec2);
|
||||||
}
|
}
|
||||||
return alpha * at::outer(vec1, vec2);
|
} else if (beta.toComplexDouble() == 1.0) {
|
||||||
}
|
|
||||||
|
|
||||||
if (beta.toComplexDouble() == 1.0) {
|
|
||||||
if (alpha.toComplexDouble() == 1.0) {
|
if (alpha.toComplexDouble() == 1.0) {
|
||||||
return self + at::outer(vec1, vec2);
|
out = self + at::outer(vec1, vec2);
|
||||||
|
} else {
|
||||||
|
out = self + alpha * at::outer(vec1, vec2);
|
||||||
}
|
}
|
||||||
return self + alpha * at::outer(vec1, vec2);
|
} else if (alpha.toComplexDouble() == 1.0) {
|
||||||
|
out = beta * self + at::outer(vec1, vec2);
|
||||||
|
} else {
|
||||||
|
out = beta * self + alpha * at::outer(vec1, vec2);
|
||||||
}
|
}
|
||||||
|
auto result_type = c10::promoteTypes(c10::promoteTypes(self.scalar_type(), vec1.scalar_type()), vec2.scalar_type());
|
||||||
if (alpha.toComplexDouble() == 1.0) {
|
return out.to(c10::TensorOptions().dtype(result_type));
|
||||||
return beta * self + at::outer(vec1, vec2);
|
|
||||||
}
|
|
||||||
return beta * self + alpha * at::outer(vec1, vec2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor& math_addr_out(const Tensor& self,
|
Tensor& math_addr_out(const Tensor& self,
|
||||||
|
|||||||
@ -206,6 +206,16 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
|
|||||||
const int normalized_ndim = normalized_shape.size();
|
const int normalized_ndim = normalized_shape.size();
|
||||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||||
const int axis = input_ndim - normalized_ndim;
|
const int axis = input_ndim - normalized_ndim;
|
||||||
|
|
||||||
|
// Properly handle zero-size inputs: the view(1, M, -1) call below breaks on this.
|
||||||
|
if (input.numel() == 0) {
|
||||||
|
auto result_type = c10::promoteTypes(input.scalar_type(), kFloat);
|
||||||
|
return std::make_tuple(
|
||||||
|
at::empty_like(input),
|
||||||
|
at::empty_like(input, c10::TensorOptions().dtype(result_type)),
|
||||||
|
at::empty_like(input, c10::TensorOptions().dtype(result_type))
|
||||||
|
);
|
||||||
|
}
|
||||||
at::Tensor input_reshaped = input.view({1, M, -1});
|
at::Tensor input_reshaped = input.view({1, M, -1});
|
||||||
// Unlike Batch Normalization, which applies scalar scale and bias for each
|
// Unlike Batch Normalization, which applies scalar scale and bias for each
|
||||||
// entire channel/plane with the affine option, Layer Normalization applies
|
// entire channel/plane with the affine option, Layer Normalization applies
|
||||||
|
|||||||
@ -965,6 +965,7 @@ class Generator(object):
|
|||||||
# Defined in torch/csrc/utils/python_dispatch.cpp
|
# Defined in torch/csrc/utils/python_dispatch.cpp
|
||||||
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
|
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
|
||||||
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
|
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
|
||||||
|
def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
|
||||||
def _dispatch_has_kernel(name: str) -> _bool: ...
|
def _dispatch_has_kernel(name: str) -> _bool: ...
|
||||||
def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
|
def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
|
||||||
def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...
|
def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...
|
||||||
|
|||||||
@ -110,11 +110,15 @@ def register_decomposition(aten_op, registry=None, *, disable_meta: bool = False
|
|||||||
# which don't have corresponding dispatcher entries, we need
|
# which don't have corresponding dispatcher entries, we need
|
||||||
# to filter those out
|
# to filter those out
|
||||||
and torch._C._dispatch_has_kernel(name)
|
and torch._C._dispatch_has_kernel(name)
|
||||||
# Don't register a meta kernel to any operator that has
|
# Don't register a python meta kernel to any operator that has
|
||||||
# a CompositeImplicitAutograd kernel in core.
|
# should already work with meta tensors today.
|
||||||
# Otherwise we won't be able to run autograd for that operator with the meta backend.
|
# We can check that by seeing if the "computed table" for the operator
|
||||||
and "CompositeImplicitAutograd" not in torch._C._dispatch_dump(name)
|
# has a registration to Meta;
|
||||||
and not torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta")
|
# either through a direct registration, or an indirect one through
|
||||||
|
# an alias dispatch key (e.g. CompositeImplicitAutograd)
|
||||||
|
and not torch._C._dispatch_has_computed_kernel_for_dispatch_key(
|
||||||
|
name, "Meta"
|
||||||
|
)
|
||||||
):
|
):
|
||||||
if any(
|
if any(
|
||||||
a.alias_info is not None and not a.alias_info.is_write
|
a.alias_info is not None and not a.alias_info.is_write
|
||||||
|
|||||||
@ -292,6 +292,8 @@ void initDispatchBindings(PyObject* module) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
|
// Returns whether or not a direct kernel registration exists
|
||||||
|
// for this <op_name, dispatch_key> pair.
|
||||||
"_dispatch_has_kernel_for_dispatch_key",
|
"_dispatch_has_kernel_for_dispatch_key",
|
||||||
[](const char* name, const char* dispatch) -> bool {
|
[](const char* name, const char* dispatch) -> bool {
|
||||||
auto op =
|
auto op =
|
||||||
@ -300,6 +302,22 @@ void initDispatchBindings(PyObject* module) {
|
|||||||
return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
|
return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
// Returns whether or not there is an entry in the runtime computed
|
||||||
|
// dispatch table, for this <op_name, dispatch_key> pair. For example, if
|
||||||
|
// "op" has a `CompositeImplicitAutograd` kernel, Then
|
||||||
|
// _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
|
||||||
|
// true for all backends that are part of the alias set for
|
||||||
|
// CompositeImplicitAutograd.
|
||||||
|
"_dispatch_has_computed_kernel_for_dispatch_key",
|
||||||
|
[](const char* name, const char* dispatch) -> bool {
|
||||||
|
auto op =
|
||||||
|
c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
|
||||||
|
TORCH_CHECK(op, "operator ", name, " does not exist");
|
||||||
|
return op->hasComputedKernelForDispatchKey(
|
||||||
|
c10::parseDispatchKey(dispatch));
|
||||||
|
});
|
||||||
|
|
||||||
m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
|
m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
|
||||||
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
|
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user