skip a second call to shouldUseRecordFunction for BackendSelect ops (#50891)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50891

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D25999514

Pulled By: bdhirsh

fbshipit-source-id: 8a6c17ab502fe463cf3fb38a1e555c64bc5556f0
Committed by Facebook GitHub Bot
Parent: 7b9ca54ecf
Commit: 2303c244fc
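
This change removes the Dispatcher::callWithDispatchKey entry point. Previously, Dispatcher::call ran the RecordFunction profiling check and dispatched; a BackendSelect kernel then computed the real backend key and re-entered the dispatcher through callWithDispatchKey, which ran the same profiling check a second time (which is why shouldRecord had to filter BackendSelect frames out). After this change, call performs the check once, the generated BackendSelect wrappers compute a full DispatchKeySet, and redispatch goes straight to the backend kernel. A toy model of the before/after control flow (plain C++; these functions only mimic the dispatcher and are not the ATen API):

// Toy model of the control flow before/after this change (not ATen code).
#include <cstdio>

bool should_run_record_function() {   // stands in for at::shouldRunRecordFunction
  std::puts("  profiling check");
  return false;                       // pretend no callbacks are active
}

void backend_kernel() { std::puts("  backend kernel"); }

void redispatch() {                   // new path: no second profiling check
  backend_kernel();
}

void call_with_dispatch_key() {       // old path: checks profiling again
  should_run_record_function();
  backend_kernel();
}

void backend_select_kernel(bool after_this_change) {
  if (after_this_change) redispatch();
  else call_with_dispatch_key();
}

void dispatcher_call(bool after_this_change) {  // stands in for Dispatcher::call
  should_run_record_function();                 // the one (formerly first) check
  backend_select_kernel(after_this_change);
}

int main() {
  std::puts("before:"); dispatcher_call(false);  // two profiling checks
  std::puts("after:");  dispatcher_call(true);   // one profiling check
}
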
@@ -72,10 +72,6 @@ private:
   friend class OperatorHandle;
   template<class> friend class TypedOperatorHandle;
 
-  // Helper utility function for internal use only.
-  template<class Return, class... Args>
-  Return _callWithDispatchKeySet(const TypedOperatorHandle<Return(Args...)>& op, const KernelFunction& kernel, DispatchKeySet dispatchKeySet, Args... args) const;
-
 public:
   ~Dispatcher();
 

@@ -134,11 +130,6 @@ public:
   template<class Return, class... Args>
   Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;
 
-  // Like call, but override the default DispatchKey calculation code,
-  // instead dispatching straight to the provided DispatchKey
-  template<class Return, class... Args>
-  C10_ALWAYS_INLINE
-  Return callWithDispatchKey(const TypedOperatorHandle<Return (Args...)>& op, DispatchKey dispatchKey, Args... args) const;
 
   template<class Return, class... Args>
   static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);

@@ -219,13 +210,6 @@ public:
 
   void checkInvariants() const;
 
-  /* Check if operator calls with a given dispatch key
-   * need to be observed with RecordFunction.
-   */
-  inline static bool shouldRecord(DispatchKey dispatch_key) {
-    return dispatch_key != DispatchKey::BackendSelect;
-  }
-
   //
   // ------------------------------------------------------------------------
   //

@@ -375,10 +359,6 @@ public:
     return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
   }
 
-  C10_ALWAYS_INLINE Return callWithDispatchKey(DispatchKey dispatchKey, Args... args) const {
-    return c10::Dispatcher::singleton().callWithDispatchKey<Return, Args...>(*this, dispatchKey, std::forward<Args>(args)...);
-  }
-
   C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
     return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
   }

@@ -393,38 +373,6 @@ namespace detail {
 template<class... Args> inline void unused_arg_(const Args&...) {}
 }
 
-template<class Return, class... Args>
-inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(Args...)>& op, DispatchKey dispatchKey, Args... args) const {
-  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
-  // No alias dispatch key is allowed at runtime.
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKey));
-  auto dispatchKeySet = op.operatorIterator_->op.dispatchKeyExtractor()
-    .template getDispatchKeySetUnboxed<Args...>(
-      DispatchKeySet(DispatchKeySet::FULL_AFTER, dispatchKey),
-      args...
-    );
-  const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKey);
-  return _callWithDispatchKeySet<Return, Args...>(op, kernel, dispatchKeySet, args...);
-}
-
-// Note: benchmarks showed that this function wasn't getting inlined during calls to at::empty
-template<class Return, class... Args>
-C10_ALWAYS_INLINE Return Dispatcher::_callWithDispatchKeySet(const TypedOperatorHandle<Return(Args...)>& op, const KernelFunction& kernel, DispatchKeySet dispatchKeySet, Args... args) const {
-#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
-  // By default, when there're no high-frequency or non-sampled callbacks,
-  // RecordFunction is pre-sampled as a perf optimization;
-  // shouldRunRecordFunction checks whether RecordFunction should be executed,
-  // and sets pre_sampled boolean argument value to whether pre-sampling was used -
-  // this boolean is passed into RecordFunction to adjust the sampling rates of
-  // the callbacks
-  bool pre_sampled = false;
-  if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
-    return callWithDispatchKeySlowPath<Return, Args...>(op, pre_sampled, dispatchKeySet, kernel, std::forward<Args>(args)...);
-  }
-#endif // PYTORCH_DISABLE_PER_OP_PROFILING
-  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
-}
-
 template<class Return, class... Args>
 inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
   // Check if we need to run callbacks registered with RecordFunction

@@ -436,7 +384,7 @@ inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<
   at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
   if (C10_UNLIKELY(guard.isActive())) {
     auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
-    if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
+    if (op.operatorIterator_->op.isObserved()) {
       if (guard.needsInputs()) {
         runRecordFunction(guard, op, dispatchKey, impl::boxArgs(args...));
       } else {

@@ -458,7 +406,19 @@ C10_ALWAYS_INLINE Return Dispatcher::call(const TypedOperatorHandle<Return(Args.
     );
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
   const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKeySet.highestPriorityTypeId());
-  return _callWithDispatchKeySet<Return, Args...>(op, kernel, dispatchKeySet, args...);
+#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
+  // By default, when there're no high-frequency or non-sampled callbacks,
+  // RecordFunction is pre-sampled as a perf optimization;
+  // shouldRunRecordFunction checks whether RecordFunction should be executed,
+  // and sets pre_sampled boolean argument value to whether pre-sampling was used -
+  // this boolean is passed into RecordFunction to adjust the sampling rates of
+  // the callbacks
+  bool pre_sampled = false;
+  if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
+    return callWithDispatchKeySlowPath<Return, Args...>(op, pre_sampled, dispatchKeySet, kernel, std::forward<Args>(args)...);
+  }
+#endif // PYTORCH_DISABLE_PER_OP_PROFILING
+  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
 }
 
 template<class Return, class... Args>

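The block added to call above is the body that previously lived in _callWithDispatchKeySet: a cheap, usually-false check decides whether any RecordFunction callbacks might fire, and only then does the call take the outlined slow path. A minimal sketch of this pre-sampled fast-path pattern (toy code under simplified assumptions, not the ATen implementation):

// Sketch of the pre-sampled fast-path pattern (toy code, not ATen).
#include <atomic>
#include <cstdio>

#define UNLIKELY(x) __builtin_expect(!!(x), 0)   // models C10_UNLIKELY

std::atomic<long> g_sample_countdown{1000};

// Models at::shouldRunRecordFunction: a cheap, usually-false check.
// *pre_sampled reports that the sampling shortcut was taken, so the slow
// path can compensate when adjusting callback sampling rates.
bool should_run_record_function(bool* pre_sampled) {
  *pre_sampled = true;
  return g_sample_countdown.fetch_sub(1, std::memory_order_relaxed) <= 0;
}

int kernel(int x) { return 2 * x; }

// Outlined slow path: only reached when profiling callbacks may fire.
int call_slow_path(int x, bool pre_sampled) {
  std::printf("running callbacks (pre_sampled=%d)\n", pre_sampled);
  return kernel(x);
}

int call(int x) {                 // models the new Dispatcher::call body
  bool pre_sampled = false;
  if (UNLIKELY(should_run_record_function(&pre_sampled))) {
    return call_slow_path(x, pre_sampled);   // rare: profiling active
  }
  return kernel(x);               // common case: straight to the kernel
}

int main() { std::printf("%d\n", call(21)); }
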
@@ -481,7 +441,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
   at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
   if (C10_UNLIKELY(guard.isActive())) {
     auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
-    if (shouldRecord(dispatchKey) && entry.isObserved()) {
+    if (entry.isObserved()) {
       if (guard.needsInputs()) {
         runRecordFunction(guard, op, dispatchKey, *stack);
       } else {

@@ -246,13 +246,13 @@ TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnbox
 
   // Ensure that dispatcher doesn't take the dispatch key from the tensor but from the direct argument instead.
   called_kernel_cpu = false;
-  callOpUnboxedWithDispatchKey<void, Tensor>(*op, c10::DispatchKey::CPU, dummyTensor(c10::DispatchKey::CUDA));
+  callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA));
   EXPECT_TRUE(called_kernel_cpu);
 
   // Ensure that dispatch key from tensor is not used here.
   called_kernel_cpu = false;
   expectThrows<c10::Error>([&] {
-    callOpUnboxedWithDispatchKey<void, Tensor>(*op, c10::DispatchKey::CUDA, dummyTensor(c10::DispatchKey::CPU));
+    callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU));
   }, "Could not run '_test::dummy' with arguments from the 'CUDA'"
   " backend.");
 }

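The test helper is renamed to match the new entry point: instead of a single DispatchKey, the caller now hands the dispatcher a precomputed DispatchKeySet. The helper itself is not part of this excerpt; a hypothetical sketch of its shape, assuming it simply forwards to TypedOperatorHandle::redispatch (the method shown in the header diff above):

// Hypothetical shape of the renamed test helper (an assumption -- the real
// helper lives in the op_registration test utilities, which this diff does
// not show): forward a caller-computed DispatchKeySet to redispatch instead
// of letting the dispatcher recompute the keys from the arguments.
#include <ATen/core/dispatch/Dispatcher.h>
#include <utility>

template <class Return, class... Args>
Return callOpUnboxedWithPrecomputedDispatchKeySet(
    const c10::OperatorHandle& op, c10::DispatchKeySet ks, Args... args) {
  return op.typed<Return(Args...)>().redispatch(ks, std::forward<Args>(args)...);
}
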
@@ -356,9 +356,9 @@ class ComputeBackendSelect:
             compute_dk = f"""\
 DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
 DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
-DispatchKey _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask).highestPriorityTypeId();"""
+DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
         else:
-            compute_dk = f"DispatchKey _dk = {dispatch_key};"
+            compute_dk = f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
         return f"""\
 // aten::{f.func}
 C10_ALWAYS_INLINE

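The _dk_mask line relies on how DispatchKeySet(FULL_AFTER, key) is constructed: it contains every key of strictly lower priority than key, so intersecting it with the computed set strips BackendSelect (and anything above it) and leaves the true backend as the highest-priority bit for redispatch to find. A toy bitmask model (illustrative key numbering; the real computeDispatchKeySet also folds in the thread-local include/exclude sets, which this sketch ignores):

// Toy bitmask model of the _dk_mask trick (not the real c10 key numbering).
#include <cstdint>
#include <cstdio>

using KeySet = std::uint64_t;

// Pretend numbering: higher bit index = higher dispatch priority.
constexpr int kCUDA = 2, kBackendSelect = 20;

constexpr KeySet full_after(int key) {   // every key strictly below `key`
  return (KeySet{1} << key) - 1;
}

int highest_priority(KeySet s) {         // index of the most significant set bit
  return 63 - __builtin_clzll(s);
}

int main() {
  KeySet dk_set = (KeySet{1} << kBackendSelect) | (KeySet{1} << kCUDA);
  KeySet dk = dk_set & full_after(kBackendSelect);   // BackendSelect stripped
  std::printf("redispatch target: %d (kCUDA = %d)\n",
              highest_priority(dk), kCUDA);          // prints 2
}
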
@@ -367,7 +367,7 @@ C10_ALWAYS_INLINE
     .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
     .typed<{dispatcher_sig.type()}>();
   {compute_dk}
-  return op.callWithDispatchKey(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
+  return op.redispatch(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
 }}
 """
         elif self.target is Target.REGISTRATION:

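For a concrete picture of what the codegen now emits, instantiating the template above for a made-up op aten::foo(Tensor self) -> Tensor would produce roughly the following (a sketch; the op name and schema are invented, boilerplate is elided, and real BackendSelect wrappers are typically factory functions taking TensorOptions):

// Hypothetical expansion of the codegen template above (sketch only).
C10_ALWAYS_INLINE
at::Tensor foo(const at::Tensor& self) {
  static auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("aten::foo", "")
    .typed<at::Tensor(const at::Tensor&)>();
  // Collect keys from the arguments, then mask out BackendSelect and above.
  DispatchKeySet _dk_set = c10::DispatchKeySet(DispatchKey::BackendSelect)
      | c10::detail::multi_dispatch_key_set(self);
  DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
  DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);
  // redispatch() skips the RecordFunction check that call() already did.
  return op.redispatch(_dk, self);
}
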