mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Some small performance fixes for c10 dispatcher (#20472)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20472 ghimport-source-id: d118bf8d48eea3faf241a7288fcad1bb6a5f051f Differential Revision: D15332284 Pulled By: li-roy fbshipit-source-id: a8d9e50a440a7ad3ee730f70c0fcae06ae848cbd
This commit is contained in:
committed by
Facebook Github Bot
parent
d9dcfacd9e
commit
e42665cf39
@ -61,7 +61,7 @@ class KernelTable_ final {
|
||||
if (!emplaced.second) {
|
||||
// Element already existed. Overwrite it.
|
||||
emplaced.first->second = value;
|
||||
AT_WARN("Registered a kernel that overwrote a previoulsy registered kernel with same dispatch key '",
|
||||
AT_WARN("Registered a kernel that overwrote a previously registered kernel with same dispatch key '",
|
||||
detail::dispatch_key_to_string(key), "' for operator '", operator_name ,"'.");
|
||||
}
|
||||
}
|
||||
@ -205,7 +205,7 @@ private:
|
||||
bool is_valid_;
|
||||
|
||||
TensorTypeId get_dispatch_key(const Stack* stack) const {
|
||||
auto first_tensor_arg = torch::jit::peek(
|
||||
const IValue& first_tensor_arg = torch::jit::peek(
|
||||
*stack,
|
||||
0,
|
||||
reverse_index_of_first_tensor_arg_
|
||||
@ -217,8 +217,7 @@ private:
|
||||
}
|
||||
return tensor_list[0].type_id();
|
||||
} else {
|
||||
// TODO Avoid bumping the refcounter
|
||||
return first_tensor_arg.toTensor().type_id();
|
||||
return first_tensor_arg.unsafeToTensorImpl()->type_id();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -132,6 +132,9 @@ struct CAFFE2_API IValue final {
|
||||
bool isTensor() const { return Tag::Tensor == tag; }
|
||||
at::Tensor toTensor() &&;
|
||||
at::Tensor toTensor() const &;
|
||||
at::TensorImpl* unsafeToTensorImpl() const {
|
||||
return static_cast<at::TensorImpl*>(payload.as_intrusive_ptr);
|
||||
}
|
||||
|
||||
const IValue& toIValue() const {
|
||||
return *this;
|
||||
|
@ -219,7 +219,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenRegistering_t
|
||||
testing::internal::CaptureStderr();
|
||||
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<DummyKernel>(), dispatchKey(TensorType1()));
|
||||
std::string output = testing::internal::GetCapturedStderr();
|
||||
EXPECT_THAT(output, testing::HasSubstr("Registered a kernel that overwrote a previoulsy registered kernel with same dispatch key"));
|
||||
EXPECT_THAT(output, testing::HasSubstr("Registered a kernel that overwrote a previously registered kernel with same dispatch key"));
|
||||
}
|
||||
|
||||
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenCalled_thenCallsNewerKernel) {
|
||||
|
Reference in New Issue
Block a user