[Vulkan] Add Vulkan Rewrite to Transfer Inputs and Outputs to Vulkan and CPU Backends Respectively (#87432)

With this change, we don't have to manually invoke transferring input and output backends when we run vulkan models.

Graph rewrite code based off of:
- 32efff45ba (diff-a473bddb458dc24225866a45092d6eca064eddd256245d93020e48e216eee4d5R160-R179)

Differential Revision: [D39519168](https://our.internmc.facebook.com/intern/diff/D39519168/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39519168/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87432
Approved by: https://github.com/mcr229, https://github.com/digantdesai
This commit is contained in:
Salil Desai
2022-10-30 20:30:57 -07:00
committed by PyTorch MergeBot
parent bc68625151
commit df1cc0ef47
9 changed files with 83 additions and 10 deletions

View File

@@ -195,14 +195,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
+    const bool requires_backend_transfers =
+        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
     for (size_t i = 0; i < n; i++) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_) {
+      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
         inputs.push_back(
             atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
                                 : std::move(atIValue));
       } else {
-        TORCH_CHECK(at::kCPU == deviceType_);
+        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
         inputs.push_back(std::move(atIValue));
       }
     }
@@ -223,14 +225,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
+    const bool requires_backend_transfers =
+        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
     for (size_t i = 0; i < n; i++) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_) {
+      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
         inputs.push_back(
             atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
                                 : std::move(atIValue));
       } else {
-        TORCH_CHECK(at::kCPU == deviceType_);
+        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
         inputs.push_back(std::move(atIValue));
       }
     }

View File

@@ -158,14 +158,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
     inputs.reserve(n);
+    const bool requires_backend_transfers =
+        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
     for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_) {
+      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
         inputs.push_back(
             atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
                                 : std::move(atIValue));
       } else {
-        TORCH_CHECK(at::kCPU == deviceType_);
+        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
         inputs.push_back(std::move(atIValue));
       }
     }
@@ -187,14 +189,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
     std::vector<at::IValue> inputs{};
     size_t n = jinputs->size();
    inputs.reserve(n);
+    const bool requires_backend_transfers =
+        module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
     for (const auto i : c10::irange(n)) {
       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-      if (at::kVulkan == deviceType_) {
+      if (at::kVulkan == deviceType_ && requires_backend_transfers) {
         inputs.push_back(
             atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
                                 : std::move(atIValue));
       } else {
-        TORCH_CHECK(at::kCPU == deviceType_);
+        TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
         inputs.push_back(std::move(atIValue));
       }
     }