Add a cpu_dispatch_key parameter to the cpu_fallback function (#134321)

Fixes #134322
Add a cpu_dispatch_key parameter to the cpu_fallback function so that the fallback can redispatch to a CPU-backend dispatch key other than the default CPU key, for example SparseCPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134321
Approved by: https://github.com/albanD
This commit is contained in:
xpfjmj
2024-08-27 19:57:55 +00:00
committed by PyTorch MergeBot
parent adf401f822
commit b744ed6816
2 changed files with 8 additions and 3 deletions

View File

@ -87,7 +87,11 @@ static bool validate_tensor_list(const c10::List<at::Tensor>& tensorlist) {
return flag;
}
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) {
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views,
c10::DispatchKey cpu_dispatch_key) {
TORCH_CHECK(c10::BackendComponent::CPUBit == c10::toBackendComponent(cpu_dispatch_key),
"Expected CPU backend DispatchKey but got ",
c10::toString(cpu_dispatch_key));
auto& schema_args = op.schema().arguments();
const auto num_arguments = schema_args.size();
auto arguments = torch::jit::last(stack, num_arguments);
@ -143,7 +147,7 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
}
// Step 2: Call the underlying CPU implementation of the operator
op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack);
op.redispatchBoxed(c10::DispatchKeySet(cpu_dispatch_key), stack);
// Step 3: We need to take special care to handle mutable aliases properly:
// If any input tensors are mutable aliases, we need to

View File

@ -11,7 +11,8 @@ namespace at::native {
// This function implements a boxed fallback to CPU.
// External backends can add their own custom logging on top of it to customize their own CPU fallbacks.
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false,
c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU);
// This is a helper function that backends can use to directly call their boxed CPU fallback
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.