Reduce overhead in CUDAGraph Trees (#98529)

Significantly reduces the overhead of constructing Tensors and Storages and of checking Storage liveness. It removes the regression for the HF models I tested and removes 75% of the overhead of the extremely overhead-bound resnet50 training we have in torchbench (0.91x base commit, 1.02x torchinductor default, 1.16x this PR, 1.25x previous cudagraphs impl).

This PR takes care of all of the lower-hanging fruit.

- Computes storage aliasing at record time instead of at runtime. We no longer need a runtime storage cache; instead we index directly into the existing alias if there is one, or construct a new Storage (see the first sketch after this list).

- Moves the heavyweight C++ calls (getting storage weakrefs and constructing tensors) into batched calls (see the second sketch after this list).
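
A minimal sketch of the record-time aliasing idea, for illustration only (the helper name and the use of the storage base pointer as the aliasing key are assumptions, not the PR's actual code):

    def compute_alias_plan(outputs):
        # storage base pointer -> index of the first output seen with that
        # storage; keying on the pointer is an illustrative simplification.
        storage_to_index = {}
        plan = []
        for i, t in enumerate(outputs):
            key = t.untyped_storage().data_ptr()
            if key in storage_to_index:
                # Alias of an earlier output: record that index so the runtime
                # path can reuse the already-constructed tensor's Storage.
                plan.append(storage_to_index[key])
            else:
                storage_to_index[key] = i
                plan.append(None)  # runtime will construct a fresh Storage
        return plan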

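A minimal sketch of how the batched construction binding added in this PR might be driven from Python at replay time (the torch._C module path and the wrapper name are assumptions; the metadata keys mirror the C++ code below):

    import torch

    def reconstruct_outputs(alias_plan, metadatas):
        # alias_plan[i] is None (build a new Storage from the recorded
        # data_ptr/nbytes/device), an int (reuse the Storage of the output
        # already built at that index), or an existing Storage.
        # metadatas[i] is None for a None output, or a dict such as:
        #   {"nbytes": 1024, "data_ptr": ptr, "device": torch.device("cuda:0"),
        #    "dtype": torch.float32, "size": [4, 64], "stride": [64, 1],
        #    "storage_offset": 0}
        outputs = []
        torch._C._construct_Tensors_From_Storage_and_Metadata(
            alias_plan, metadatas, outputs
        )
        return outputs
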
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98529
Approved by: https://github.com/jansel, https://github.com/ngimel
Author: Elias Ellison
Date: 2023-04-06 23:40:39 +00:00
Committed by: PyTorch MergeBot
Parent: 616f50da3a
Commit: 5c8fea5647
5 changed files with 366 additions and 67 deletions


@@ -1,7 +1,13 @@
#include <ATen/ATen.h>
#include <ATen/core/TensorBody.h>
#include <ATen/cuda/CUDAConfig.h>
#include <c10/core/Device.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/UniqueVoidPtr.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <unordered_set>
#if AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/Macros.h>
@@ -1059,6 +1065,79 @@ static void registerCudaPluggableAllocator(PyObject* module) {
return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter());
});
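  // Batched helper: for each output that is not backed by a persistent
  // storage, hand back a raw weak reference to its StorageImpl and its data
  // pointer (both as integers) so Python can track storage liveness without a
  // per-tensor round trip into C++.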
m.def(
"_map_Storage_Refs",
[](const py::sequence& outputs,
const py::list& outputs_persistent_storage,
py::list output_refs,
py::list output_data_ptrs) {
TORCH_CHECK(outputs.size() == outputs_persistent_storage.size());
for (size_t i = 0, end = outputs.size(); i < end; ++i) {
if (!outputs_persistent_storage[i].is_none() ||
outputs[i].is_none()) {
output_refs.append(py::none());
output_data_ptrs.append(py::none());
continue;
}
auto t = outputs[i].cast<at::Tensor>();
c10::StorageImpl* storage = t.storage().unsafeGetStorageImpl();
auto weak = c10::raw::intrusive_ptr::make_weak(storage);
output_refs.append(reinterpret_cast<size_t>(weak));
output_data_ptrs.append(
reinterpret_cast<size_t>(storage->data_ptr().get()));
}
});
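  // Batched helper: construct all output Tensors in a single call. For each
  // entry, storages[i] is py::none (build a fresh Storage from the recorded
  // data_ptr/nbytes/device), an integer index (alias the Storage of an output
  // constructed earlier in this call), or an existing Storage; metadatas[i]
  // supplies dtype, size, stride, and storage_offset.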
m.def(
"_construct_Tensors_From_Storage_and_Metadata",
[](const py::list& storages,
const py::list& metadatas,
py::list& outputs) {
TORCH_CHECK(storages.size() == metadatas.size());
for (size_t i = 0, end = storages.size(); i < end; ++i) {
const auto& maybe_metadata = metadatas[i];
if (maybe_metadata.is_none()) {
outputs.append(py::none());
continue;
}
const py::dict& metadata = maybe_metadata.cast<py::dict>();
c10::Storage s;
if (storages[i].is_none()) {
s = c10::Storage(
c10::Storage::use_byte_size_t(),
metadata["nbytes"].cast<int64_t>(),
at::DataPtr(
reinterpret_cast<void*>(
metadata["data_ptr"].cast<size_t>()),
metadata["device"].cast<c10::Device>()));
} else if (py::isinstance<py::int_>(storages[i])) {
s = outputs[storages[i].cast<int64_t>()]
.cast<at::Tensor>()
.storage();
} else {
s = storages[i].cast<c10::Storage>();
}
auto dtype_arg = metadata["dtype"].ptr();
auto meta = scalarTypeToTypeMeta(toScalarType(dtype_arg));
constexpr c10::DispatchKeySet cuda_dks(c10::DispatchKey::CUDA);
at::Tensor tensor = at::detail::make_tensor_base<c10::TensorImpl>(
std::move(s), cuda_dks, meta);
tensor.unsafeGetTensorImpl()->set_sizes_and_strides(
metadata["size"].cast<std::vector<int64_t>>(),
metadata["stride"].cast<std::vector<int64_t>>());
tensor.unsafeGetTensorImpl()->set_storage_offset(
metadata["storage_offset"].cast<int64_t>());
outputs.append(std::move(tensor));
}
});
m.def(
"_cuda_beginAllocateCurrentStreamToPool",
[](int device, at::cuda::MempoolId_t mempool_id) {