Add a cpu_dispatch_key parameter to the cpu_fallback function (#134321)
Fixes #134322

Add a cpu_dispatch_key parameter to the cpu_fallback function to support fallback, for example, to SparseCPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134321
Approved by: https://github.com/albanD
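As a usage sketch (my own illustration, not part of this commit; the fallback function name and the SparsePrivateUse1 registration are assumptions), an out-of-tree backend could route unimplemented sparse ops to the SparseCPU kernels instead of the dense CPU ones via the new parameter:

    #include <ATen/native/CPUFallback.h>
    #include <torch/library.h>

    // Hypothetical boxed fallback for an out-of-tree sparse backend:
    // redispatch any unimplemented op to the SparseCPU kernels rather than dense CPU.
    static void sparse_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
      at::native::cpu_fallback(op, stack, /*error_on_views=*/false,
                               /*cpu_dispatch_key=*/c10::DispatchKey::SparseCPU);
    }

    // Hypothetical registration: catch every op that reaches the SparsePrivateUse1 key.
    TORCH_LIBRARY_IMPL(_, SparsePrivateUse1, m) {
      m.fallback(torch::CppFunction::makeFromBoxedFunction<&sparse_cpu_fallback>());
    }

The TORCH_CHECK added below rejects any key whose backend component is not CPUBit, so passing, say, c10::DispatchKey::SparseCUDA here would raise an error.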
aten/src/ATen/native/CPUFallback.cpp
@@ -87,7 +87,11 @@ static bool validate_tensor_list(const c10::List<at::Tensor>& tensorlist) {
   return flag;
 }
 
-void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) {
+void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views,
+                  c10::DispatchKey cpu_dispatch_key) {
+  TORCH_CHECK(c10::BackendComponent::CPUBit == c10::toBackendComponent(cpu_dispatch_key),
+              "Expected CPU backend DispatchKey but got ",
+              c10::toString(cpu_dispatch_key));
   auto& schema_args = op.schema().arguments();
   const auto num_arguments = schema_args.size();
   auto arguments = torch::jit::last(stack, num_arguments);
@@ -143,7 +147,7 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
   }
 
   // Step 2: Call the underlying CPU implementation of the operator
-  op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack);
+  op.redispatchBoxed(c10::DispatchKeySet(cpu_dispatch_key), stack);
 
   // Step 3: We need to take special care to handle mutable aliases properly:
   // If any input tensors are mutable aliases, we need to
aten/src/ATen/native/CPUFallback.h
@@ -11,7 +11,8 @@ namespace at::native {
 
 // This function implements a boxed fallback to CPU.
 // External backends can add their own custom logging on top of it to customize their own CPU fallbacks.
-TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);
+TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false,
+                            c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU);
 
 // This is a helper function that backends can use to directly call their boxed CPU fallback
 // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
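For reference, a small standalone sketch (my own illustration, not part of the diff) of what the new check accepts: any runtime key whose backend component is CPUBit, such as SparseCPU or QuantizedCPU, while keys on other backends are rejected; the defaulted cpu_dispatch_key keeps existing call sites on dense CPU.

    #include <c10/core/DispatchKey.h>
    #include <iostream>

    int main() {
      // Keys that decompose to the CPU backend component pass the new TORCH_CHECK...
      std::cout << (c10::toBackendComponent(c10::DispatchKey::SparseCPU) ==
                    c10::BackendComponent::CPUBit)
                << "\n";  // prints 1
      // ...while keys on other backends (e.g. SparseCUDA -> CUDABit) would trip it.
      std::cout << (c10::toBackendComponent(c10::DispatchKey::SparseCUDA) ==
                    c10::BackendComponent::CPUBit)
                << "\n";  // prints 0
    }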