skip a second call to shouldUseRecordFunction for BackendSelect ops (#50891)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50891

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D25999514

Pulled By: bdhirsh

fbshipit-source-id: 8a6c17ab502fe463cf3fb38a1e555c64bc5556f0
This commit is contained in:
Brian Hirsh
2021-02-08 18:30:21 -08:00
committed by Facebook GitHub Bot
parent 7b9ca54ecf
commit 2303c244fc
3 changed files with 20 additions and 60 deletions

View File

@ -72,10 +72,6 @@ private:
friend class OperatorHandle;
template<class> friend class TypedOperatorHandle;
// Helper utility function for internal use only.
template<class Return, class... Args>
Return _callWithDispatchKeySet(const TypedOperatorHandle<Return(Args...)>& op, const KernelFunction& kernel, DispatchKeySet dispatchKeySet, Args... args) const;
public:
~Dispatcher();
@ -134,11 +130,6 @@ public:
template<class Return, class... Args>
Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;
// Like call, but override the default DispatchKey calculation code,
// instead dispatching straight to the provided DispatchKey
template<class Return, class... Args>
C10_ALWAYS_INLINE
Return callWithDispatchKey(const TypedOperatorHandle<Return (Args...)>& op, DispatchKey dispatchKey, Args... args) const;
template<class Return, class... Args>
static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);
@ -219,13 +210,6 @@ public:
void checkInvariants() const;
/* Check if operator calls with a given dispatch key
* need to be observed with RecordFunction.
*/
inline static bool shouldRecord(DispatchKey dispatch_key) {
return dispatch_key != DispatchKey::BackendSelect;
}
//
// ------------------------------------------------------------------------
//
@ -375,10 +359,6 @@ public:
return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
}
C10_ALWAYS_INLINE Return callWithDispatchKey(DispatchKey dispatchKey, Args... args) const {
return c10::Dispatcher::singleton().callWithDispatchKey<Return, Args...>(*this, dispatchKey, std::forward<Args>(args)...);
}
C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
}
@ -393,38 +373,6 @@ namespace detail {
template<class... Args> inline void unused_arg_(const Args&...) {}
}
template<class Return, class... Args>
inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(Args...)>& op, DispatchKey dispatchKey, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
// No alias dispatch key is allowed at runtime.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKey));
auto dispatchKeySet = op.operatorIterator_->op.dispatchKeyExtractor()
.template getDispatchKeySetUnboxed<Args...>(
DispatchKeySet(DispatchKeySet::FULL_AFTER, dispatchKey),
args...
);
const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKey);
return _callWithDispatchKeySet<Return, Args...>(op, kernel, dispatchKeySet, args...);
}
// Note: benchmarks showed that this function wasn't getting inlined during calls to at::empty
template<class Return, class... Args>
C10_ALWAYS_INLINE Return Dispatcher::_callWithDispatchKeySet(const TypedOperatorHandle<Return(Args...)>& op, const KernelFunction& kernel, DispatchKeySet dispatchKeySet, Args... args) const {
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// By default, when there're no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
// shouldRunRecordFunction checks whether RecordFunction should be executed,
// and sets pre_sampled boolean argument value to whether pre-sampling was used -
// this boolean is passed into RecordFunction to adjust the sampling rates of
// the callbacks
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
return callWithDispatchKeySlowPath<Return, Args...>(op, pre_sampled, dispatchKeySet, kernel, std::forward<Args>(args)...);
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}
template<class Return, class... Args>
inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
// Check if we need to run callbacks registered with RecordFunction
@ -436,7 +384,7 @@ inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
if (op.operatorIterator_->op.isObserved()) {
if (guard.needsInputs()) {
runRecordFunction(guard, op, dispatchKey, impl::boxArgs(args...));
} else {
@ -458,7 +406,19 @@ C10_ALWAYS_INLINE Return Dispatcher::call(const TypedOperatorHandle<Return(Args.
);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKeySet.highestPriorityTypeId());
return _callWithDispatchKeySet<Return, Args...>(op, kernel, dispatchKeySet, args...);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// By default, when there're no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
// shouldRunRecordFunction checks whether RecordFunction should be executed,
// and sets pre_sampled boolean argument value to whether pre-sampling was used -
// this boolean is passed into RecordFunction to adjust the sampling rates of
// the callbacks
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
return callWithDispatchKeySlowPath<Return, Args...>(op, pre_sampled, dispatchKeySet, kernel, std::forward<Args>(args)...);
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}
template<class Return, class... Args>
@ -481,7 +441,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
if (shouldRecord(dispatchKey) && entry.isObserved()) {
if (entry.isObserved()) {
if (guard.needsInputs()) {
runRecordFunction(guard, op, dispatchKey, *stack);
} else {

View File

@ -246,13 +246,13 @@ TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnbox
// Ensure that dispatcher doesn't take the dispatch key from the tensor but from the direct argument instead.
called_kernel_cpu = false;
callOpUnboxedWithDispatchKey<void, Tensor>(*op, c10::DispatchKey::CPU, dummyTensor(c10::DispatchKey::CUDA));
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA));
EXPECT_TRUE(called_kernel_cpu);
// Ensure that dispatch key from tensor is not used here.
called_kernel_cpu = false;
expectThrows<c10::Error>([&] {
callOpUnboxedWithDispatchKey<void, Tensor>(*op, c10::DispatchKey::CUDA, dummyTensor(c10::DispatchKey::CPU));
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU));
}, "Could not run '_test::dummy' with arguments from the 'CUDA'"
" backend.");
}

View File

@ -356,9 +356,9 @@ class ComputeBackendSelect:
compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKey _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask).highestPriorityTypeId();"""
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
else:
compute_dk = f"DispatchKey _dk = {dispatch_key};"
compute_dk = f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
@ -367,7 +367,7 @@ C10_ALWAYS_INLINE
.findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
.typed<{dispatcher_sig.type()}>();
{compute_dk}
return op.callWithDispatchKey(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
return op.redispatch(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
elif self.target is Target.REGISTRATION: