mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Some small performance fixes for c10 dispatcher (#20472)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20472 ghimport-source-id: d118bf8d48eea3faf241a7288fcad1bb6a5f051f Differential Revision: D15332284 Pulled By: li-roy fbshipit-source-id: a8d9e50a440a7ad3ee730f70c0fcae06ae848cbd
This commit is contained in:
committed by
Facebook Github Bot
parent
d9dcfacd9e
commit
e42665cf39
@ -61,7 +61,7 @@ class KernelTable_ final {
|
|||||||
if (!emplaced.second) {
|
if (!emplaced.second) {
|
||||||
// Element already existed. Overwrite it.
|
// Element already existed. Overwrite it.
|
||||||
emplaced.first->second = value;
|
emplaced.first->second = value;
|
||||||
AT_WARN("Registered a kernel that overwrote a previoulsy registered kernel with same dispatch key '",
|
AT_WARN("Registered a kernel that overwrote a previously registered kernel with same dispatch key '",
|
||||||
detail::dispatch_key_to_string(key), "' for operator '", operator_name ,"'.");
|
detail::dispatch_key_to_string(key), "' for operator '", operator_name ,"'.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -205,7 +205,7 @@ private:
|
|||||||
bool is_valid_;
|
bool is_valid_;
|
||||||
|
|
||||||
TensorTypeId get_dispatch_key(const Stack* stack) const {
|
TensorTypeId get_dispatch_key(const Stack* stack) const {
|
||||||
auto first_tensor_arg = torch::jit::peek(
|
const IValue& first_tensor_arg = torch::jit::peek(
|
||||||
*stack,
|
*stack,
|
||||||
0,
|
0,
|
||||||
reverse_index_of_first_tensor_arg_
|
reverse_index_of_first_tensor_arg_
|
||||||
@ -217,8 +217,7 @@ private:
|
|||||||
}
|
}
|
||||||
return tensor_list[0].type_id();
|
return tensor_list[0].type_id();
|
||||||
} else {
|
} else {
|
||||||
// TODO Avoid bumping the refcounter
|
return first_tensor_arg.unsafeToTensorImpl()->type_id();
|
||||||
return first_tensor_arg.toTensor().type_id();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -132,6 +132,9 @@ struct CAFFE2_API IValue final {
|
|||||||
bool isTensor() const { return Tag::Tensor == tag; }
|
bool isTensor() const { return Tag::Tensor == tag; }
|
||||||
at::Tensor toTensor() &&;
|
at::Tensor toTensor() &&;
|
||||||
at::Tensor toTensor() const &;
|
at::Tensor toTensor() const &;
|
||||||
|
at::TensorImpl* unsafeToTensorImpl() const {
|
||||||
|
return static_cast<at::TensorImpl*>(payload.as_intrusive_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
const IValue& toIValue() const {
|
const IValue& toIValue() const {
|
||||||
return *this;
|
return *this;
|
||||||
|
@ -219,7 +219,7 @@ TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenRegistering_t
|
|||||||
testing::internal::CaptureStderr();
|
testing::internal::CaptureStderr();
|
||||||
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<DummyKernel>(), dispatchKey(TensorType1()));
|
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<DummyKernel>(), dispatchKey(TensorType1()));
|
||||||
std::string output = testing::internal::GetCapturedStderr();
|
std::string output = testing::internal::GetCapturedStderr();
|
||||||
EXPECT_THAT(output, testing::HasSubstr("Registered a kernel that overwrote a previoulsy registered kernel with same dispatch key"));
|
EXPECT_THAT(output, testing::HasSubstr("Registered a kernel that overwrote a previously registered kernel with same dispatch key"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenCalled_thenCallsNewerKernel) {
|
TEST(OperatorRegistrationTest, givenKernelsWithSameDispatchKey_whenCalled_thenCallsNewerKernel) {
|
||||||
|
Reference in New Issue
Block a user