mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow Undefined to get kernel from Math/DefaultBackend. (#46352)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46352 Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D24319417 Pulled By: ailzhang fbshipit-source-id: de2d7db2cb931b0dcf2fbabd7d292e22cfc5e7b7
This commit is contained in:
committed by
Facebook GitHub Bot
parent
908c23579d
commit
7f458e16ba
@ -221,7 +221,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2.1 Use DefaultBackend kernel if available.
|
// 2.1 Use DefaultBackend kernel if available.
|
||||||
if (isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) {
|
// See Note [Undefined in dispatchTable_] for the special handling for Undefined.
|
||||||
|
if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) {
|
||||||
if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::DefaultBackend)) {
|
if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::DefaultBackend)) {
|
||||||
return {*default_backend_registration.value(), "default backend kernel"};
|
return {*default_backend_registration.value(), "default backend kernel"};
|
||||||
}
|
}
|
||||||
@ -236,7 +237,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
|
|||||||
// when there's no direct registration to its corresponding backend key or DefaultBackend.
|
// when there's no direct registration to its corresponding backend key or DefaultBackend.
|
||||||
// For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration
|
// For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration
|
||||||
// to any of its backends.
|
// to any of its backends.
|
||||||
if (isIncludedInAlias(dispatch_key, DispatchKey::Math)) {
|
// See Note [Undefined in dispatchTable_] for the special handling for Undefined.
|
||||||
|
if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::Math)) {
|
||||||
if (auto math_registration = getKernelForDispatchKey(DispatchKey::Math)) {
|
if (auto math_registration = getKernelForDispatchKey(DispatchKey::Math)) {
|
||||||
if (dispatch_key == DispatchKey::AutogradOther
|
if (dispatch_key == DispatchKey::AutogradOther
|
||||||
&& hasKernelForAnyDispatchKey(c10::autogradother_backends)) {
|
&& hasKernelForAnyDispatchKey(c10::autogradother_backends)) {
|
||||||
@ -304,6 +306,11 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
|
|||||||
for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
|
for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
|
||||||
updateDispatchTableEntry_(dispatcher, k);
|
updateDispatchTableEntry_(dispatcher, k);
|
||||||
}
|
}
|
||||||
|
// Registration to DefaultBackend and Math should be populated to Undefined.
|
||||||
|
// We cannot do this above since Undefined cannot be represented in DispatchKeySet.
|
||||||
|
if (dispatch_key == DispatchKey::Math || dispatch_key == DispatchKey::DefaultBackend) {
|
||||||
|
updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined);
|
||||||
|
}
|
||||||
// Note [Refresh Runtime Autograd entries in dispatchTable_]
|
// Note [Refresh Runtime Autograd entries in dispatchTable_]
|
||||||
// Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
|
// Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
|
||||||
if (c10::isBackendDispatchKey(dispatch_key)) {
|
if (c10::isBackendDispatchKey(dispatch_key)) {
|
||||||
@ -324,11 +331,16 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
|
|||||||
//
|
//
|
||||||
void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
|
void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
|
||||||
// Note [Undefined in dispatchTable_]
|
// Note [Undefined in dispatchTable_]
|
||||||
|
// DispatchKey Undefined is used in runtime:
|
||||||
// (1) it gives people place to specify functionality that should run when there are no dispatch keys,
|
// (1) it gives people place to specify functionality that should run when there are no dispatch keys,
|
||||||
// e.g., an empty TensorList argument
|
// e.g., an op without Tensor inputs or empty TensorList arguments
|
||||||
// (2) it would let us remove the explicit error checking code in the dispatch hotpath, and so when
|
// (2) it would let us remove the explicit error checking code in the dispatch hotpath, and so when
|
||||||
// no dispatch keys are available we just slide into the undefined handler which would then raise
|
// no dispatch keys are available we just slide into the undefined handler which would then raise
|
||||||
// the error message./
|
// the error message.
|
||||||
|
// In the old world of catchAll, the only way to "register" a kernel to Undefined is by registering it to
|
||||||
|
// catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either DefaultBackend
|
||||||
|
// or Math alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, Math)
|
||||||
|
// should return true, it returns false because Undefined cannot be represented in a DispatchKeySet.
|
||||||
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
|
for (uint8_t iter = 0; iter != static_cast<uint8_t>(DispatchKey::NumDispatchKeys); ++iter) {
|
||||||
updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
|
updateDispatchTable_(dispatcher, static_cast<DispatchKey>(iter));
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,7 @@ import re
|
|||||||
Result = namedtuple('Result', 'state table provenance')
|
Result = namedtuple('Result', 'state table provenance')
|
||||||
|
|
||||||
dispatch_keys_to_check = (
|
dispatch_keys_to_check = (
|
||||||
|
'Undefined',
|
||||||
'CPU',
|
'CPU',
|
||||||
'CUDA',
|
'CUDA',
|
||||||
'XLA',
|
'XLA',
|
||||||
@ -341,6 +342,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: default_def_name_t_t [catch all]
|
||||||
CPU: fn_cpu [kernel]
|
CPU: fn_cpu [kernel]
|
||||||
CUDA: default_def_name_t_t [catch all]
|
CUDA: default_def_name_t_t [catch all]
|
||||||
XLA: fn_xla [kernel]
|
XLA: fn_xla [kernel]
|
||||||
@ -372,6 +374,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: default_def_name_t_t [catch all]
|
||||||
CPU: impl_t_t [kernel]
|
CPU: impl_t_t [kernel]
|
||||||
CUDA: default_def_name_t_t [catch all]
|
CUDA: default_def_name_t_t [catch all]
|
||||||
XLA: default_def_name_t_t [catch all]
|
XLA: default_def_name_t_t [catch all]
|
||||||
@ -402,6 +405,7 @@ Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: impl_t_t [math kernel]
|
||||||
CPU: impl_t_t [math kernel]
|
CPU: impl_t_t [math kernel]
|
||||||
CUDA: impl_t_t [math kernel]
|
CUDA: impl_t_t [math kernel]
|
||||||
XLA: impl_t_t [math kernel]
|
XLA: impl_t_t [math kernel]
|
||||||
@ -435,6 +439,7 @@ Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: fn_math [math kernel]
|
||||||
CPU: fn_cpu [kernel]
|
CPU: fn_cpu [kernel]
|
||||||
CUDA: fn_math [math kernel]
|
CUDA: fn_math [math kernel]
|
||||||
XLA: fn_math [math kernel]
|
XLA: fn_math [math kernel]
|
||||||
@ -498,6 +503,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: fn_math [math kernel]
|
||||||
CPU: fn_cpu [kernel]
|
CPU: fn_cpu [kernel]
|
||||||
CUDA: fn_math [math kernel]
|
CUDA: fn_math [math kernel]
|
||||||
XLA: fn_math [math kernel]
|
XLA: fn_math [math kernel]
|
||||||
@ -531,6 +537,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: default_def_name_t_t [catch all]
|
||||||
CPU: fn_cpu [kernel]
|
CPU: fn_cpu [kernel]
|
||||||
CUDA: default_def_name_t_t [catch all]
|
CUDA: default_def_name_t_t [catch all]
|
||||||
XLA: default_def_name_t_t [catch all]
|
XLA: default_def_name_t_t [catch all]
|
||||||
@ -564,6 +571,7 @@ catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: fn_math [math kernel]
|
||||||
CPU: fn_math [math kernel]
|
CPU: fn_math [math kernel]
|
||||||
CUDA: fn_math [math kernel]
|
CUDA: fn_math [math kernel]
|
||||||
XLA: fn_math [math kernel]
|
XLA: fn_math [math kernel]
|
||||||
@ -597,6 +605,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed u
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: fn_defaultbackend [default backend kernel]
|
||||||
CPU: fn_cpu [kernel]
|
CPU: fn_cpu [kernel]
|
||||||
CUDA: fn_defaultbackend [default backend kernel]
|
CUDA: fn_defaultbackend [default backend kernel]
|
||||||
XLA: fn_defaultbackend [default backend kernel]
|
XLA: fn_defaultbackend [default backend kernel]
|
||||||
@ -632,6 +641,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed u
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check + ('QuantizedCPU',))
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: fn_defaultbackend [default backend kernel]
|
||||||
CPU: fn_cpu [kernel]
|
CPU: fn_cpu [kernel]
|
||||||
CUDA: fn_defaultbackend [default backend kernel]
|
CUDA: fn_defaultbackend [default backend kernel]
|
||||||
XLA: fn_defaultbackend [default backend kernel]
|
XLA: fn_defaultbackend [default backend kernel]
|
||||||
@ -671,6 +681,7 @@ DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed u
|
|||||||
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
extracted_table = extract_dispatch_table_with_keys(table, dispatch_keys_to_check)
|
||||||
|
|
||||||
self.assertExpectedInline(extracted_table, '''\
|
self.assertExpectedInline(extracted_table, '''\
|
||||||
|
Undefined: fn_defaultbackend [default backend kernel]
|
||||||
CPU: fn_cpu [kernel]
|
CPU: fn_cpu [kernel]
|
||||||
CUDA: fn_defaultbackend [default backend kernel]
|
CUDA: fn_defaultbackend [default backend kernel]
|
||||||
XLA: fn_defaultbackend [default backend kernel]
|
XLA: fn_defaultbackend [default backend kernel]
|
||||||
|
Reference in New Issue
Block a user