Allow Undefined to get kernel from Math/DefaultBackend. (#46352)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46352

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D24319417

Pulled By: ailzhang

fbshipit-source-id: de2d7db2cb931b0dcf2fbabd7d292e22cfc5e7b7
Author:     Ailing Zhang
Date:       2020-10-15 11:14:52 -07:00
Committer:  Facebook GitHub Bot
Parent:     908c23579d
Commit:     7f458e16ba

2 changed files with 27 additions and 4 deletions
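
For readers less familiar with the alias keys: after this change, a kernel registered to the Math or DefaultBackend alias key (later renamed CompositeImplicitAutograd and CompositeExplicitAutograd) also fills the Undefined slot of an operator's dispatch table, which is where dispatch lands when no argument carries a dispatch key. A minimal sketch of a registration this affects follows; the myops::my_sum operator, its kernel, and the helper are hypothetical, made up for illustration, and not code from this PR.

// Hypothetical example (not part of this PR): a Math-only registration.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <vector>

at::Tensor my_sum(at::TensorList xs) {
  at::Tensor out = at::zeros({});
  for (const auto& x : xs) {
    out = out + x;
  }
  return out;
}

TORCH_LIBRARY(myops, m) {
  m.def("my_sum(Tensor[] xs) -> Tensor");
}

// Registering to the Math alias key; after this PR the kernel also
// populates the Undefined entry of the dispatch table.
TORCH_LIBRARY_IMPL(myops, Math, m) {
  m.impl("my_sum", TORCH_FN(my_sum));
}

at::Tensor call_with_no_tensors() {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("myops::my_sum", "")
      .typed<at::Tensor(at::TensorList)>();
  std::vector<at::Tensor> no_tensors;  // no Tensor inputs -> no dispatch keys
  return op.call(no_tensors);          // dispatches via Undefined to my_sum
}

Previously the Undefined slot could only be filled through a legacy catchAll registration; with this change a Math or DefaultBackend kernel covers that case as well.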

File 1 of 2:

@@ -221,7 +221,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
   }
   // 2.1 Use DefaultBackend kernel if available.
-  if (isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) {
+  // See Note [Undefined in dispatchTable_] for the special handling for Undefined.
+  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) {
     if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::DefaultBackend)) {
       return {*default_backend_registration.value(), "default backend kernel"};
     }
@@ -236,7 +237,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
   // when there's no direct registration to its corresponding backend key or DefaultBackend.
   // For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration
   // to any of its backends.
-  if (isIncludedInAlias(dispatch_key, DispatchKey::Math)) {
+  // See Note [Undefined in dispatchTable_] for the special handling for Undefined.
+  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::Math)) {
     if (auto math_registration = getKernelForDispatchKey(DispatchKey::Math)) {
       if (dispatch_key == DispatchKey::AutogradOther
           && hasKernelForAnyDispatchKey(c10::autogradother_backends)) {
@@ -304,6 +306,11 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
   for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
     updateDispatchTableEntry_(dispatcher, k);
   }
+  // Registration to DefaultBackend and Math should be populated to Undefined.
+  // We cannot do this above since Undefined cannot be represented in DispatchKeySet.
+  if (dispatch_key == DispatchKey::Math || dispatch_key == DispatchKey::DefaultBackend) {
+    updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined);
+  }
   // Note [Refresh Runtime Autograd entries in dispatchTable_]
   // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
   if (c10::isBackendDispatchKey(dispatch_key)) {
@@ -324,11 +331,16 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
 //
 void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
   // Note [Undefined in dispatchTable_]
+  // DispatchKey Undefined is used in runtime:
   // (1) it gives people place to specify functionality that should run when there are no dispatch keys,
-  //     e.g., an empty TensorList argument
+  //     e.g., an op without Tensor inputs or empty TensorList arguments
   // (2) it would let us remove the explicit error checking code in the dispatch hotpath, and so when
   //     no dispatch keys are available we just slide into the undefined handler which would then raise
-  //     the error message./
+  //     the error message.
+  // In the old world of catchAll, the only way to "register" a kernel to Undefined is by registering it to
+  // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either DefaultBackend
+  // or Math alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, Math)
+  // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet.
   for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
     updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
   }
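
The new Note [Undefined in dispatchTable_] above says that ideally isIncludedInAlias(Undefined, Math) would return true, but it cannot because Undefined is not representable in a DispatchKeySet. Below is a toy, self-contained sketch of that representation constraint; it is illustrative only, not c10's actual DispatchKeySet (see c10/core/DispatchKeySet.h for the real one).

// Toy model: real dispatch keys map to one bit each in a 64-bit mask, while
// Undefined maps to "no bits set" (the empty set). No set, including the
// runtime set behind an alias key like Math, can therefore contain Undefined,
// which is why the diff adds explicit `dispatch_key == DispatchKey::Undefined`
// checks and populates the Undefined entry separately in updateDispatchTable_.
#include <cstdint>

enum class ToyDispatchKey : uint8_t { Undefined = 0, CPU, CUDA, XLA /* ... */ };

struct ToyDispatchKeySet {
  uint64_t repr = 0;  // 0 == empty set

  ToyDispatchKeySet() = default;  // the empty set

  explicit ToyDispatchKeySet(ToyDispatchKey k)
      : repr(k == ToyDispatchKey::Undefined
                 ? 0  // Undefined has no bit of its own
                 : uint64_t(1) << (static_cast<uint8_t>(k) - 1)) {}

  bool has(ToyDispatchKey k) const {
    // Always false when k == Undefined, since its representation is zero.
    return (repr & ToyDispatchKeySet(k).repr) != 0;
  }
};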

File 2 of 2:

@@ -25,6 +25,7 @@ import re
 Result = namedtuple('Result', 'state table provenance')
 dispatch_keys_to_check = (
+    'Undefined',
     'CPU',
     'CUDA',
     'XLA',
@@ -341,6 +342,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: default_def_name_t_t [catch all]
 CPU: fn_cpu [kernel]
 CUDA: default_def_name_t_t [catch all]
 XLA: fn_xla [kernel]
@@ -372,6 +374,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: default_def_name_t_t [catch all]
 CPU: impl_t_t [kernel]
 CUDA: default_def_name_t_t [catch all]
 XLA: default_def_name_t_t [catch all]
@@ -402,6 +405,7 @@ Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: impl_t_t [math kernel]
 CPU: impl_t_t [math kernel]
 CUDA: impl_t_t [math kernel]
 XLA: impl_t_t [math kernel]
@@ -435,6 +439,7 @@ Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_math [math kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
@@ -498,6 +503,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_math [math kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
@@ -531,6 +537,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: default_def_name_t_t [catch all]
 CPU: fn_cpu [kernel]
 CUDA: default_def_name_t_t [catch all]
 XLA: default_def_name_t_t [catch all]
@@ -564,6 +571,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_math [math kernel]
 CPU: fn_math [math kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
@@ -597,6 +605,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]
@@ -632,6 +641,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]
@@ -671,6 +681,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]
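
Read together, the expected tables above encode the same precedence for the new Undefined row that the C++ change implements: a DefaultBackend kernel wins, then a Math kernel, then the legacy catch-all registration. Below is a hedged, self-contained paraphrase of just that ordering, in plain C++ with made-up names; it is not the real dispatcher code.

// Illustrative only: which kernel the Undefined entry resolves to, phrased in
// the same bracketed labels the expected test strings use above.
#include <optional>
#include <string>

struct UndefinedCandidates {
  std::optional<std::string> default_backend;  // DispatchKey::DefaultBackend
  std::optional<std::string> math;             // DispatchKey::Math
  std::optional<std::string> catch_all;        // legacy catchAll registration
};

std::string undefined_entry(const UndefinedCandidates& r) {
  if (r.default_backend) return *r.default_backend + " [default backend kernel]";
  if (r.math)            return *r.math + " [math kernel]";
  if (r.catch_all)       return *r.catch_all + " [catch all]";
  return "";  // nothing registered: the slot stays on the missing-kernel error path
}

For example, the cases above that register both a catch-all def and a Math impl expect "Undefined: ... [math kernel]", while the catch-all-only cases keep "Undefined: ... [catch all]".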