diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index 6bb871794c73..ba7d9c4c8c35 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -221,7 +221,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
   }
 
   // 2.1 Use DefaultBackend kernel if available.
-  if (isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) {
+  // See Note [Undefined in dispatchTable_] for the special handling for Undefined.
+  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) {
     if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::DefaultBackend)) {
       return {*default_backend_registration.value(), "default backend kernel"};
     }
@@ -236,7 +237,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
   // when there's no direct registration to its corresponding backend key or DefaultBackend.
   // For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration
   // to any of its backends.
-  if (isIncludedInAlias(dispatch_key, DispatchKey::Math)) {
+  // See Note [Undefined in dispatchTable_] for the special handling for Undefined.
+  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::Math)) {
     if (auto math_registration = getKernelForDispatchKey(DispatchKey::Math)) {
       if (dispatch_key == DispatchKey::AutogradOther
           && hasKernelForAnyDispatchKey(c10::autogradother_backends)) {
@@ -304,6 +306,11 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
   for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
     updateDispatchTableEntry_(dispatcher, k);
   }
+  // Registration to DefaultBackend and Math should be populated to Undefined.
+  // We cannot do this above since Undefined cannot be represented in DispatchKeySet.
+  if (dispatch_key == DispatchKey::Math || dispatch_key == DispatchKey::DefaultBackend) {
+    updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined);
+  }
   // Note [Refresh Runtime Autograd entries in dispatchTable_]
   // Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
   if (c10::isBackendDispatchKey(dispatch_key)) {
@@ -324,11 +331,16 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
 //
 void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
   // Note [Undefined in dispatchTable_]
+  // DispatchKey Undefined is used in runtime:
   // (1) it gives people place to specify functionality that should run when there are no dispatch keys,
-  //     e.g., an empty TensorList argument
+  //     e.g., an op without Tensor inputs or empty TensorList arguments
   // (2) it would let us remove the explicit error checking code in the dispatch hotpath, and so when
   //     no dispatch keys are available we just slide into the undefined handler which would then raise
-  //     the error message./
+  //     the error message.
+  // In the old world of catchAll, the only way to "register" a kernel to Undefined is by registering it to
+  // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either DefaultBackend
+  // or Math alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, Math)
+  // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet.
   for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
     updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
   }
diff --git a/test/test_dispatch.py b/test/test_dispatch.py
index ee5b03832acc..fd8fb0e236e2 100644
--- a/test/test_dispatch.py
+++ b/test/test_dispatch.py
@@ -25,6 +25,7 @@ import re
 Result = namedtuple('Result', 'state table provenance')
 
 dispatch_keys_to_check = (
+    'Undefined',
     'CPU',
     'CUDA',
     'XLA',
@@ -341,6 +342,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: default_def_name_t_t [catch all]
 CPU: fn_cpu [kernel]
 CUDA: default_def_name_t_t [catch all]
 XLA: fn_xla [kernel]
@@ -372,6 +374,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: default_def_name_t_t [catch all]
 CPU: impl_t_t [kernel]
 CUDA: default_def_name_t_t [catch all]
 XLA: default_def_name_t_t [catch all]
@@ -402,6 +405,7 @@ Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: impl_t_t [math kernel]
 CPU: impl_t_t [math kernel]
 CUDA: impl_t_t [math kernel]
 XLA: impl_t_t [math kernel]
@@ -435,6 +439,7 @@ Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_math [math kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
@@ -498,6 +503,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_math [math kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
@@ -531,6 +537,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: default_def_name_t_t [catch all]
 CPU: fn_cpu [kernel]
 CUDA: default_def_name_t_t [catch all]
 XLA: default_def_name_t_t [catch all]
@@ -564,6 +571,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_math [math kernel]
 CPU: fn_math [math kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
@@ -597,6 +605,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed u
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]
@@ -632,6 +641,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed u
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]
@@ -671,6 +681,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed u
 
         extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
         self.assertExpectedInline(extracted_table, '''\
+Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]
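The sketch below is not part of the patch above; it only illustrates the registration pattern the change is meant to keep working once catchAllKernel_ is gone: a kernel registered to the Math alias key now also fills the Undefined slot, so an op invoked with no dispatch keys (no Tensor inputs, or an empty TensorList) still finds a kernel. The namespace sketch_ns and the op count are hypothetical; the TORCH_LIBRARY / TORCH_LIBRARY_IMPL macros and the Math / DefaultBackend alias keys are the registration API of this era (the alias keys were later renamed CompositeImplicitAutograd / CompositeExplicitAutograd).

// Hypothetical example, not part of this patch.
#include <ATen/ATen.h>
#include <torch/library.h>

// An op whose only argument is a TensorList: calling it with an empty list
// produces no dispatch keys, so dispatch lands on DispatchKey::Undefined.
int64_t count_impl(at::TensorList tensors) {
  return static_cast<int64_t>(tensors.size());
}

TORCH_LIBRARY(sketch_ns, m) {
  m.def("count(Tensor[] tensors) -> int");
}

// Previously only a catch-all registration populated the Undefined entry.
// With the updateDispatchTable_ change above, this Math registration is also
// copied into Undefined, so calling sketch_ns::count with an empty list keeps
// working without a catch-all kernel.
TORCH_LIBRARY_IMPL(sketch_ns, Math, m) {
  m.impl("count", count_impl);
}

Registering the same kernel to DefaultBackend instead of Math would exercise the 2.1 branch of computeDispatchTableEntryWithDebug rather than the Math branch, which is why both alias keys are special-cased for Undefined in updateDispatchTable_.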