Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Allow Undefined to get kernel from Math/DefaultBackend. (#46352)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46352
Test Plan: Imported from OSS
Reviewed By: ezyang
Differential Revision: D24319417
Pulled By: ailzhang
fbshipit-source-id: de2d7db2cb931b0dcf2fbabd7d292e22cfc5e7b7
Committed by: Facebook GitHub Bot
Parent: 908c23579d
Commit: 7f458e16ba
aten/src/ATen/core/dispatch/OperatorEntry.cpp

@@ -221,7 +221,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
   }
 
   // 2.1 Use DefaultBackend kernel if available.
-  if (isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) {
+  //     See Note [Undefined in dispatchTable_] for the special handling for Undefined.
+  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) {
     if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::DefaultBackend)) {
       return {*default_backend_registration.value(), "default backend kernel"};
     }
@@ -236,7 +237,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
   //      when there's no direct registration to its corresponding backend key or DefaultBackend.
   //      For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration
   //      to any of its backends.
-  if (isIncludedInAlias(dispatch_key, DispatchKey::Math)) {
+  //     See Note [Undefined in dispatchTable_] for the special handling for Undefined.
+  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::Math)) {
     if (auto math_registration = getKernelForDispatchKey(DispatchKey::Math)) {
       if (dispatch_key == DispatchKey::AutogradOther
           && hasKernelForAnyDispatchKey(c10::autogradother_backends)) {
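Taken together, the two hunks above make the Undefined key eligible for a kernel registered to DefaultBackend or Math. Below is a minimal Python sketch of that lookup order, not the real dispatcher code: computed_entry, BACKEND_KEYS, and AUTOGRAD_KEYS are illustrative names, and backend fallbacks, fallthrough kernels, and the AutogradOther ambiguity check are deliberately omitted.

# Minimal model of the lookup order in the hunks above (direct kernel, then
# DefaultBackend, then Math). Names and key sets are illustrative only.
BACKEND_KEYS = {"CPU", "CUDA", "XLA", "QuantizedCPU"}
AUTOGRAD_KEYS = {"AutogradCPU", "AutogradCUDA", "AutogradXLA", "AutogradOther"}

def computed_entry(dispatch_key, kernels):
    """kernels maps a runtime or alias key name to a registered kernel name."""
    # 1. A direct registration to the runtime key always wins.
    if dispatch_key in kernels:
        return kernels[dispatch_key], "kernel"
    # 2.1 DefaultBackend serves backend keys and, after this change, Undefined.
    if dispatch_key == "Undefined" or dispatch_key in BACKEND_KEYS:
        if "DefaultBackend" in kernels:
            return kernels["DefaultBackend"], "default backend kernel"
    # 2.2 Math serves backend/autograd keys and, after this change, Undefined.
    if dispatch_key == "Undefined" or dispatch_key in BACKEND_KEYS | AUTOGRAD_KEYS:
        if "Math" in kernels:
            return kernels["Math"], "math kernel"
    return None, "missing"

# A Math registration now also populates Undefined, mirroring the expected
# tables in test_dispatch.py further down.
print(computed_entry("Undefined", {"CPU": "fn_cpu", "Math": "fn_math"}))  # ('fn_math', 'math kernel')
print(computed_entry("CPU", {"CPU": "fn_cpu", "Math": "fn_math"}))        # ('fn_cpu', 'kernel')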
@@ -304,6 +306,11 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
   for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
     updateDispatchTableEntry_(dispatcher, k);
   }
+  // Registration to DefaultBackend and Math should be populated to Undefined.
+  // We cannot do this above since Undefined cannot be represented in DispatchKeySet.
+  if (dispatch_key == DispatchKey::Math || dispatch_key == DispatchKey::DefaultBackend) {
+    updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined);
+  }
   // Note [Refresh Runtime Autograd entries in dispatchTable_]
   // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
   if (c10::isBackendDispatchKey(dispatch_key)) {
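The added lines above refresh the Undefined slot explicitly whenever a Math or DefaultBackend registration changes, because Undefined cannot be a member of a DispatchKeySet and is therefore never produced by getRuntimeDispatchKeySet. A rough continuation of the previous sketch, reusing computed_entry and the key sets defined there; runtime_keys_for is an illustrative stand-in and the real key-set expansion is more involved.

def runtime_keys_for(dispatch_key):
    # Hypothetical stand-in for c10::getRuntimeDispatchKeySet; note that
    # "Undefined" is never part of the returned set.
    if dispatch_key == "DefaultBackend":
        return BACKEND_KEYS
    if dispatch_key == "Math":
        return BACKEND_KEYS | AUTOGRAD_KEYS
    return {dispatch_key}

def update_dispatch_table(table, kernels, dispatch_key):
    # Recompute the entries for every runtime key covered by dispatch_key.
    for k in runtime_keys_for(dispatch_key):
        table[k] = computed_entry(k, kernels)
    # Mirror of the added lines: alias registrations also refresh Undefined.
    if dispatch_key in ("Math", "DefaultBackend"):
        table["Undefined"] = computed_entry("Undefined", kernels)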
@@ -324,11 +331,16 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
 //
 void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
   // Note [Undefined in dispatchTable_]
+  // DispatchKey Undefined is used in runtime:
   // (1) it gives people place to specify functionality that should run when there are no dispatch keys,
-  //     e.g., an empty TensorList argument
+  //     e.g., an op without Tensor inputs or empty TensorList arguments
   // (2) it would let us remove the explicit error checking code in the dispatch hotpath, and so when
   //     no dispatch keys are available we just slide into the undefined handler which would then raise
-  //     the error message./
+  //     the error message.
+  // In the old world of catchAll, the only way to "register" a kernel to Undefined is by registering it to
+  // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either DefaultBackend
+  // or Math alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, Math)
+  // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet.
   for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
     updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
   }
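To make point (1) of the note concrete: when a call carries no dispatch keys (no Tensor inputs, or only an empty TensorList), dispatch lands on Undefined, which after this change is served by a Math or DefaultBackend registration instead of the removed catchAll. A toy illustration, reusing computed_entry from the first sketch; extracted_dispatch_key is an illustrative stand-in, not the real key-extraction code.

def extracted_dispatch_key(tensor_keys):
    # tensor_keys: dispatch keys collected from the Tensor arguments of a call.
    # With no Tensor inputs (or an empty TensorList) the collection is empty
    # and the call dispatches to Undefined.
    return next(iter(tensor_keys), "Undefined")

print(extracted_dispatch_key(["CPU"]))  # CPU
print(extracted_dispatch_key([]))       # Undefined
print(computed_entry(extracted_dispatch_key([]), {"Math": "fn_math"}))  # ('fn_math', 'math kernel')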
test/test_dispatch.py

@@ -25,6 +25,7 @@ import re
 Result = namedtuple('Result', 'state table provenance')
 
 dispatch_keys_to_check = (
+    'Undefined',
     'CPU',
     'CUDA',
     'XLA',
@@ -341,6 +342,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: default_def_name_t_t [catch all]
 CPU: fn_cpu [kernel]
 CUDA: default_def_name_t_t [catch all]
 XLA: fn_xla [kernel]
@@ -372,6 +374,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: default_def_name_t_t [catch all]
 CPU: impl_t_t [kernel]
 CUDA: default_def_name_t_t [catch all]
 XLA: default_def_name_t_t [catch all]
@@ -402,6 +405,7 @@ Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: impl_t_t [math kernel]
 CPU: impl_t_t [math kernel]
 CUDA: impl_t_t [math kernel]
 XLA: impl_t_t [math kernel]
@@ -435,6 +439,7 @@ Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_math [math kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
@@ -498,6 +503,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_math [math kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
@@ -531,6 +537,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: default_def_name_t_t [catch all]
 CPU: fn_cpu [kernel]
 CUDA: default_def_name_t_t [catch all]
 XLA: default_def_name_t_t [catch all]
@@ -564,6 +571,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_math [math kernel]
 CPU: fn_math [math kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
@@ -597,6 +605,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed u
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]
@@ -632,6 +641,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed u
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]
@@ -671,6 +681,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed u
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
 
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]