#include <torch/csrc/lazy/python/init.h>

#include <ATen/FunctionalTensorWrapper.h>
#include <c10/core/Device.h>
#include <torch/csrc/jit/python/pybind.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/trie.h>
#include <torch/csrc/lazy/python/python_util.h>
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#endif // FBCODE_CAFFE2 || OVRSOURCE
#include <string>
#include <unordered_set> // std::unordered_set is used below for DeviceData handle dedup
#include <utility>
#include <vector>

namespace torch::lazy {

// TODO(whc) backend 'device'-related APIs are not very clear; this code could
// be simplified, but that should probably be done together with
// designing/refactoring the overall approach to get/set of default eager/lazy
// device types.
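// For example (illustrative device string): an empty string yields a
// default-constructed BackendDevice, while a string such as "lazy:0" is
// parsed through c10::Device and mapped via atenDeviceToBackendDevice.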
torch::lazy::BackendDevice GetDeviceOrCurrent(const std::string& device_str) {
  if (device_str.empty()) {
    getBackend()->GetDefaultDeviceType();
    return torch::lazy::BackendDevice();
  }
  return torch::lazy::atenDeviceToBackendDevice(c10::Device(device_str));
}

std::ptrdiff_t GetTensorId(const at::Tensor& tensor) {
  torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
  return lazy_tensor->GetUniqueId();
}

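// Collects the IR node for each (functionalized) input tensor and hands the
// node list to the caller-supplied converter (e.g. the text/dot dumpers used
// by the _get_tensors_text/_get_tensors_dot bindings below).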
std::string GetTensorsDump(
    const std::vector<at::Tensor>& tensors,
    const std::function<std::string(c10::ArrayRef<const torch::lazy::Node*>)>&
        converter) {
  std::vector<const torch::lazy::Node*> nodes;
  std::vector<torch::lazy::Value> values;
  for (auto& tensor : tensors) {
    auto inner = at::functionalization::impl::from_functional_tensor(tensor);
    torch::lazy::LazyTensorPtr lazy_tensor =
        torch::lazy::TryGetLtcTensor(inner);
    values.push_back(lazy_tensor->GetIrValue());
    nodes.push_back(values.back().node.get());
  }
  return converter(nodes);
}

std::vector<torch::lazy::LazyTensorPtr> GetLtcTensors(
    const std::vector<at::Tensor>& tensors,
    bool want_all) {
  std::vector<torch::lazy::LazyTensorPtr> lazy_tensors;
  lazy_tensors.reserve(tensors.size());
  if (want_all) {
    for (auto& tensor : tensors) {
      lazy_tensors.push_back(torch::lazy::TryGetLtcTensor(tensor));
    }
  } else {
    for (auto& tensor : tensors) {
      auto lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
      if (lazy_tensor) {
        lazy_tensors.push_back(lazy_tensor);
      }
    }
  }
  return lazy_tensors;
}

std::string GetTensorsBackendGraph(const std::vector<at::Tensor>& tensors) {
  std::vector<torch::lazy::LazyTensorPtr> lazy_tensors =
      GetLtcTensors(tensors, /*want_all=*/false);
  return torch::lazy::LazyGraphExecutor::Get()->DumpBackendComputation(
      lazy_tensors);
}

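// Flushes the pending IR graph for the given tensors by delegating to
// LazyGraphExecutor::SyncTensorsGraph; `wait` and `sync_ltc_data` are passed
// through unchanged (see that method for their exact semantics).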
void SyncTensors(
    const std::vector<at::Tensor>& tensors,
    const std::vector<std::string>& devices,
    bool wait,
    bool sync_ltc_data) {
  std::vector<torch::lazy::LazyTensorPtr> lazy_tensors =
      GetLtcTensors(tensors, /*want_all=*/false);
  torch::lazy::LazyGraphExecutor::Get()->SyncTensorsGraph(
      &lazy_tensors, devices, wait, sync_ltc_data);
}

void initLazyBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto lazy = m.def_submodule("_lazy");
  auto lazy_ts_backend = m.def_submodule("_lazy_ts_backend");

  lazy.def(
      "_mark_step",
      // TODO(whc) this API should probably change from vector<string> to
      // vector<c10::device> but in a separate PR
      [](const std::string& device_str,
         const std::vector<std::string>& devices,
         bool wait) {
        pybind11::gil_scoped_release no_gil;
        auto backend_device = GetDeviceOrCurrent(device_str);
        torch::lazy::LazyGraphExecutor::Get()->SyncLiveTensorsGraph(
            &backend_device, devices, wait);
        torch::lazy::LazyGraphExecutor::Get()->MarkStep(backend_device);
      },
      py::arg("device") = "",
      py::arg("devices"),
      py::arg("wait") = true);
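  // Illustrative Python usage (an assumption: the module handle passed to
  // initLazyBindings is exposed as torch._C, which is where the torch._lazy
  // helpers look for these bindings):
  //   torch._C._lazy._mark_step("", [], wait=True)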
  lazy.def(
      "_wait_device_ops",
      [](const std::vector<std::string>& devices) {
        pybind11::gil_scoped_release no_gil;
        // TODO: Add support for non-empty devices.
        if (!devices.empty()) {
          LOG(ERROR) << "Non-empty devices are not supported.";
        }
        torch::lazy::LazyGraphExecutor::Get()->WaitDeviceOps({});
      },
      py::arg("devices"));
  lazy.def("_reset_metrics", []() {
    torch::lazy::MetricsArena::Get()->ResetCounters();
    torch::lazy::MetricsArena::Get()->ResetMetrics();
  });
  lazy.def("_counter_names", []() { return torch::lazy::GetCounterNames(); });
  lazy.def(
      "_metrics_report", []() { return torch::lazy::CreateMetricReport(); });
  lazy.def("_counter_value", [](const std::string& name) -> py::object {
    torch::lazy::CounterData* data = torch::lazy::GetCounter(name);
    return data != nullptr ? py::cast<int64_t>(data->Value()) : py::none();
  });
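  // Illustrative flow for the metrics bindings above (module path assumed to
  // be torch._C._lazy, as noted for _mark_step):
  //   for name in torch._C._lazy._counter_names():
  //       print(name, torch._C._lazy._counter_value(name))
  //   print(torch._C._lazy._metrics_report())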
lazy.def("_get_tensor_id", [](const at::Tensor& tensor) {
|
|
return GetTensorId(tensor);
|
|
});
|
|
|
|
  lazy.def(
      "_get_tensors_text",
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        auto converter = [](c10::ArrayRef<const torch::lazy::Node*> nodes) {
          return torch::lazy::DumpUtil::ToText(nodes);
        };
        return GetTensorsDump(tensors, converter);
      });
  lazy.def(
      "_get_tensors_dot",
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        auto converter = [](c10::ArrayRef<const torch::lazy::Node*> nodes) {
          return torch::lazy::DumpUtil::ToDot(nodes);
        };
        return GetTensorsDump(tensors, converter);
      });
  lazy.def(
      "_get_tensors_backend",
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        return GetTensorsBackendGraph(tensors);
      });
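  // The graph hash is returned as the raw bytes of a hash_t value; the same
  // byte string is what _run_cached_graph (defined below) expects as its
  // hash_str argument.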
lazy.def("_get_graph_hash", [](const std::vector<at::Tensor>& tensors) {
|
|
std::vector<LazyTensorPtr> xtensors;
|
|
xtensors.reserve(tensors.size());
|
|
for (auto& tensor : tensors) {
|
|
xtensors.emplace_back(TryGetLtcTensor(tensor));
|
|
}
|
|
auto hash = LazyGraphExecutor::Get()->GetGraphHash(xtensors);
|
|
std::string bin((const char*)&hash, sizeof(hash));
|
|
return py::bytes(bin);
|
|
});
|
|
  lazy.def(
      "_sync_multi",
      [](const std::vector<at::Tensor>& tensors,
         const std::vector<std::string>& devices,
         bool wait,
         bool sync_ltc_data) {
        pybind11::gil_scoped_release no_gil;
        SyncTensors(tensors, devices, wait, sync_ltc_data);
      },
      py::arg("tensors"),
      py::arg("devices"),
      py::arg("wait") = true,
      py::arg("sync_ltc_data") = true);

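  // The bindings below expose lazy-tensor configuration and debugging knobs
  // (forced operator fallback, the IR trie cache, IR reuse, symbolic shapes,
  // and the default device type); each maps directly onto the flag or
  // accessor it names.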
lazy.def("_get_force_fallback", []() {
|
|
return torch::lazy::getLTCForceFallback();
|
|
});
|
|
lazy.def("_set_force_fallback", [](std::string newval) {
|
|
torch::lazy::getLTCForceFallback() = std::move(newval);
|
|
});
|
|
lazy.def("_clear_ir_cache", []() { TrieCache::Get()->Clear(); });
|
|
lazy.def("_dump_ir_cache", [](std::string filename) {
|
|
TrieCache::Get()->DumpToDotFile(filename);
|
|
});
|
|
lazy.def("_set_reuse_ir", [](bool val) { FLAGS_torch_lazy_reuse_ir = val; });
|
|
lazy.def("_set_symbolic_shape_mode", [](bool val) {
|
|
FLAGS_ltc_enable_symbolic_shapes = val;
|
|
});
|
|
lazy.def("_get_symbolic_shape_mode", []() {
|
|
return FLAGS_ltc_enable_symbolic_shapes;
|
|
});
|
|
lazy.def("_get_default_device_type", []() {
|
|
return getBackend()->GetDefaultDeviceType()->toString();
|
|
});
|
|
|
|
lazy_ts_backend.def("_init", []() {
|
|
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
|
|
torch::lazy::InitTorchScriptBackend();
|
|
#else
|
|
TORCH_CHECK(false, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds");
|
|
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
|
|
});
|
|
|
|
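  // Outside FBCODE/OVRSOURCE builds, _init above initializes the TorchScript
  // (TS) backend; the remaining _lazy_ts_backend bindings below operate on
  // that backend (they cast backend data to TSData/TSComputation).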
  /*
   * Return tensor ids and tensors for DeviceData nodes.
   * TODO(shunting) revisit this API for XLA
   */
  lazy_ts_backend.def(
      "_get_tensors_ts_device_data_node",
      [](const std::vector<at::Tensor>& tensors)
          -> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        std::vector<const Node*> roots;
        for (auto& tensor : tensors) {
          auto xtensor = TryGetLtcTensor(tensor);
          roots.push_back(xtensor->GetIrValue().node.get());
        }
        auto post_order = Util::ComputePostOrder(roots);
        std::vector<int64_t> tensor_ids;
        std::vector<at::IValue> ivalues;

        std::unordered_set<BackendData::Handle> data_handles_;
        for (auto nodeptr : post_order) {
          if (nodeptr->op() == *torch::lazy::ltc_device_data) {
            const auto backend_data =
                getBackend()->GetComputationDataFromNode(nodeptr);

            auto infoptr = backend_data->info();
            auto deviceDataInfoPtr =
                (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
            auto* tsDataPtr = (torch::lazy::TSData*)backend_data.get();

            // dedup DeviceData by handle
            auto handle = tsDataPtr->GetHandle();
            if (!data_handles_.insert(handle).second) {
              continue;
            }
            tensor_ids.push_back(deviceDataInfoPtr->tensor_id);
            /*
             * If the TSData contains a tensor, then the tensor id will uniquely
             * identify the tensor. We use that tensor id to find the tensor in
             * other places: e.g. in the python forward method parameters.
             *
             * If the TSData contains a scalar, the tensor id itself is not
             * important. We reuse the scalar value in future calls.
             */
            if (tsDataPtr->HasValue()) {
              ivalues.emplace_back(tsDataPtr->data());
            } else {
              TORCH_CHECK(tsDataPtr->scalar.has_value());
              ivalues.emplace_back(tsDataPtr->scalar.value());
            }
          }
        }
        return std::make_pair(tensor_ids, ivalues);
#else
        TORCH_CHECK(
            false, "TorchScript backend not yet supported in FBCODE builds");
        return std::make_pair(
            std::vector<int64_t>(), std::vector<at::IValue>());
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
      });
  // TODO(shunting) revisit this part for XLA
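  // Runs a previously compiled computation looked up by graph hash: the hash
  // bytes come from _get_graph_hash above, graph_inputs are pushed onto the
  // TS graph executor's stack, and the tensors left on the stack afterwards
  // are returned.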
  lazy_ts_backend.def(
      "_run_cached_graph",
      [](const std::string& hash_str,
         const std::vector<at::IValue>& graph_inputs) {
        std::vector<at::Tensor> result;
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        TORCH_CHECK(hash_str.size() == sizeof(hash_t));
        hash_t hash = *(hash_t*)(hash_str.c_str());
        auto cachedComputation =
            LazyGraphExecutor::Get()->GetComputationCache()->Get(hash);
        // TODO: implement a fallback mechanism, or make sure those entries
        // never get kicked out.
        TORCH_CHECK(
            cachedComputation,
            "Failed to get computation by hash. Maybe the entry got kicked out of the LRU cache");
        auto computationPtr =
            (torch::lazy::TSComputation*)cachedComputation->computation.get();

        std::vector<torch::jit::IValue> stack;
        stack.reserve(graph_inputs.size());
        for (const auto& arg : graph_inputs) {
          stack.emplace_back(arg);
        }
        computationPtr->graph_executor().run(stack);
        result.reserve(stack.size());
        for (torch::jit::IValue elem : stack) {
          result.push_back(elem.toTensor());
        }
#else
        TORCH_CHECK(
            false, "TorchScript backend not yet supported in FBCODE builds");
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        return result;
      });
lazy_ts_backend.def("_get_latest_computation_graph", []() {
|
|
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
|
|
auto computation = LazyGraphExecutor::Get()
|
|
->GetComputationCache()
|
|
->GetLatest()
|
|
->computation;
|
|
auto ts_computation = dynamic_cast<TSComputation*>(computation.get());
|
|
TORCH_CHECK(ts_computation, "Found non-TSComputation in cache");
|
|
return ts_computation->graph()->toString();
|
|
#else
|
|
TORCH_CHECK(
|
|
false, "TorchScript backend not yet supported in FBCODE builds");
|
|
return "";
|
|
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
|
|
});
|
|
|
|
  // GetPythonFramesFunction() has never worked with torchdeploy/multipy,
  // possibly because GetPythonFrames resolves to external cpython rather
  // than embedded cpython. So far this problem has only been observed
  // internally, so we will just block it off there.

#if !(defined(USE_DEPLOY))

  // When libtorch_python is loaded, we register the python frame getter;
  // otherwise, debug util simply omits python frames.
  GetPythonFramesFunction() = GetPythonFrames;

#endif // USE_DEPLOY
}

} // namespace torch::lazy