Some small performance fixes for c10 dispatcher (#20472)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20472
ghimport-source-id: d118bf8d48eea3faf241a7288fcad1bb6a5f051f

Differential Revision: D15332284

Pulled By: li-roy

fbshipit-source-id: a8d9e50a440a7ad3ee730f70c0fcae06ae848cbd
This commit is contained in:
Roy Li
2019-05-17 16:55:59 -07:00
committed by Facebook Github Bot
parent d9dcfacd9e
commit e42665cf39
3 changed files with 7 additions and 5 deletions

View File

@ -61,7 +61,7 @@ class KernelTable_ final {
if (!emplaced.second) { if (!emplaced.second) {
// Element already existed. Overwrite it. // Element already existed. Overwrite it.
emplaced.first->second = value; emplaced.first->second = value;
AT_WARN("Registered a kernel that overwrote a previoulsy registered kernel with same dispatch key '", AT_WARN("Registered a kernel that overwrote a previously registered kernel with same dispatch key '",
detail::dispatch_key_to_string(key), "' for operator '", operator_name ,"'."); detail::dispatch_key_to_string(key), "' for operator '", operator_name ,"'.");
} }
} }
@ -205,7 +205,7 @@ private:
bool is_valid_; bool is_valid_;
TensorTypeId get_dispatch_key(const Stack* stack) const { TensorTypeId get_dispatch_key(const Stack* stack) const {
auto first_tensor_arg = torch::jit::peek( const IValue& first_tensor_arg = torch::jit::peek(
*stack, *stack,
0, 0,
reverse_index_of_first_tensor_arg_ reverse_index_of_first_tensor_arg_
@ -217,8 +217,7 @@ private:
} }
return tensor_list[0].type_id(); return tensor_list[0].type_id();
} else { } else {
// TODO Avoid bumping the refcounter return first_tensor_arg.unsafeToTensorImpl()->type_id();
return first_tensor_arg.toTensor().type_id();
} }
} }
}; };

View File

@ -132,6 +132,9 @@ struct CAFFE2_API IValue final {
bool isTensor() const { return Tag::Tensor == tag; } bool isTensor() const { return Tag::Tensor == tag; }
at::Tensor toTensor() &&; at::Tensor toTensor() &&;
at::Tensor toTensor() const &; at::Tensor toTensor() const &;
at::TensorImpl* unsafeToTensorImpl() const {
return static_cast<at::TensorImpl*>(payload.as_intrusive_ptr);
}
const IValue& toIValue() const { const IValue& toIValue() const {
return *this; return *this;

View File

@ -219,7 +219,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenRegistering_t
testing::internal::CaptureStderr(); testing::internal::CaptureStderr();
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<DummyKernel>(), dispatchKey(TensorType1())); c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<DummyKernel>(), dispatchKey(TensorType1()));
std::string output = testing::internal::GetCapturedStderr(); std::string output = testing::internal::GetCapturedStderr();
EXPECT_THAT(output, testing::HasSubstr("Registered a kernel that overwrote a previoulsy registered kernel with same dispatch key")); EXPECT_THAT(output, testing::HasSubstr("Registered a kernel that overwrote a previously registered kernel with same dispatch key"));
} }
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenCalled_thenCallsNewerKernel) { TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenCalled_thenCallsNewerKernel) {