mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Vulkan] Add Vulkan Rewrite to Transfer Inputs and Outputs to Vulkan and CPU Backends Respectively (#87432)
With this change, we don't have to manually invoke transferring input and output backends when we run vulkan models.
Graph rewrite code based off of:
- 32efff45ba (diff-a473bddb458dc24225866a45092d6eca064eddd256245d93020e48e216eee4d5R160-R179)
Differential Revision: [D39519168](https://our.internmc.facebook.com/intern/diff/D39519168/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39519168/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87432
Approved by: https://github.com/mcr229, https://github.com/digantdesai
This commit is contained in:
committed by
PyTorch MergeBot
parent
bc68625151
commit
df1cc0ef47
@ -195,14 +195,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
const bool requires_backend_transfers =
|
||||
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
if (at::kVulkan == deviceType_) {
|
||||
if (at::kVulkan == deviceType_ && requires_backend_transfers) {
|
||||
inputs.push_back(
|
||||
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
|
||||
: std::move(atIValue));
|
||||
} else {
|
||||
TORCH_CHECK(at::kCPU == deviceType_);
|
||||
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
}
|
||||
@ -223,14 +225,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
const bool requires_backend_transfers =
|
||||
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
if (at::kVulkan == deviceType_) {
|
||||
if (at::kVulkan == deviceType_ && requires_backend_transfers) {
|
||||
inputs.push_back(
|
||||
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
|
||||
: std::move(atIValue));
|
||||
} else {
|
||||
TORCH_CHECK(at::kCPU == deviceType_);
|
||||
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
}
|
||||
|
@ -158,14 +158,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
const bool requires_backend_transfers =
|
||||
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
if (at::kVulkan == deviceType_) {
|
||||
if (at::kVulkan == deviceType_ && requires_backend_transfers) {
|
||||
inputs.push_back(
|
||||
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
|
||||
: std::move(atIValue));
|
||||
} else {
|
||||
TORCH_CHECK(at::kCPU == deviceType_);
|
||||
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
}
|
||||
@ -187,14 +189,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
|
||||
std::vector<at::IValue> inputs{};
|
||||
size_t n = jinputs->size();
|
||||
inputs.reserve(n);
|
||||
const bool requires_backend_transfers =
|
||||
module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
|
||||
for (const auto i : c10::irange(n)) {
|
||||
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
|
||||
if (at::kVulkan == deviceType_) {
|
||||
if (at::kVulkan == deviceType_ && requires_backend_transfers) {
|
||||
inputs.push_back(
|
||||
atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
|
||||
: std::move(atIValue));
|
||||
} else {
|
||||
TORCH_CHECK(at::kCPU == deviceType_);
|
||||
TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
|
||||
inputs.push_back(std::move(atIValue));
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user