mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
Compare commits
170 Commits
mwizak/res
...
ciflow/tru
| Author | SHA1 | Date | |
|---|---|---|---|
| 2f75d3aebe | |||
| cd51ee276b | |||
| 467c21ad9a | |||
| f44a8f5201 | |||
| 95039738e3 | |||
| 4a94591321 | |||
| 5e7272b60a | |||
| 3f0985bf89 | |||
| d96b71c3ef | |||
| 6676fe538f | |||
| 29d16a10f3 | |||
| 70183368c6 | |||
| c374b66c75 | |||
| c601a1ea72 | |||
| faa50fa6c4 | |||
| bbb546f542 | |||
| dc53fc2af2 | |||
| a2d5216c04 | |||
| 26a1088f9f | |||
| 9f3385822d | |||
| 968b72ca2c | |||
| 5ae581762b | |||
| 54911834c4 | |||
| 332e835040 | |||
| 1e1b37ed77 | |||
| d05480c236 | |||
| 8e2c6ff709 | |||
| 3e66bd8fa8 | |||
| f589cb4a72 | |||
| 369df36d49 | |||
| 7aff2fc214 | |||
| b27cc37252 | |||
| 41663a247d | |||
| 80e40a5976 | |||
| 1c337ea84b | |||
| fb2b16422d | |||
| 6d00bd774f | |||
| 2a9ceb2f1f | |||
| 303e7afcd9 | |||
| cead985182 | |||
| c2e1972c18 | |||
| 11d7c79cea | |||
| b32f36ce35 | |||
| 979efcb825 | |||
| 2c60570864 | |||
| 99836e07fb | |||
| faea6584f8 | |||
| 08d95fe5c4 | |||
| 3d383f42e9 | |||
| e57894677e | |||
| 27ac21550e | |||
| 41259dc86e | |||
| 28c48d60aa | |||
| 113b8306a8 | |||
| ab474eebfd | |||
| dbd47d2dae | |||
| 727a1aa849 | |||
| ba32ef92a6 | |||
| c659299f60 | |||
| 313bfeea17 | |||
| 4fd958b566 | |||
| 5710b9d4af | |||
| f3e5113185 | |||
| 599b39676b | |||
| e4d59f3a0a | |||
| fffb30d445 | |||
| cd38ad58a5 | |||
| c309091454 | |||
| 3e53b1fb27 | |||
| bfe4839177 | |||
| 121f110b83 | |||
| ad0668cbb8 | |||
| 7a962a2f0d | |||
| 0d8d2a0360 | |||
| f6795a4922 | |||
| 898c2037ed | |||
| f98737b13a | |||
| 97cf576c2c | |||
| f643f2e78b | |||
| c9607615a4 | |||
| d5e568d6e0 | |||
| aa53182976 | |||
| 2549707053 | |||
| 0625bbf0c7 | |||
| 3b8b6ef6fb | |||
| ddb82b6b96 | |||
| 57f2575735 | |||
| 760f4fb105 | |||
| d1aba50677 | |||
| 98984e1561 | |||
| 46bb41bd37 | |||
| 246b2fd7a0 | |||
| d90e80dd35 | |||
| 1fa721545a | |||
| 8c13a8323a | |||
| f0e9ee0bdc | |||
| f649b7bfbd | |||
| 6e6343130f | |||
| 45f34b99a9 | |||
| 5df659dcf8 | |||
| 53f09a5136 | |||
| 2e6d995297 | |||
| aa65799ee0 | |||
| 02a382b7be | |||
| 6abac60294 | |||
| 4037b4fc22 | |||
| af5bc4e801 | |||
| 108d7f193a | |||
| a8fb34cae5 | |||
| 7b122690be | |||
| 4caf34ab53 | |||
| 3409a1e033 | |||
| 66681eea1b | |||
| b7838168c3 | |||
| 26b6913b3d | |||
| ca192e08bd | |||
| 88a13fdc94 | |||
| d099b63055 | |||
| 5fd8eb6fb8 | |||
| a4a2c1cffb | |||
| 400df72bda | |||
| 33f1963cc4 | |||
| ac68c29b4c | |||
| 2832a51c54 | |||
| 3b3af30466 | |||
| 49df7a7617 | |||
| 92176bfdff | |||
| 2f41576c02 | |||
| 0b4314efe2 | |||
| f092c0e8ce | |||
| 0c224957b6 | |||
| a912652a09 | |||
| afb24fc9c1 | |||
| 7017bcda6e | |||
| 98b5a2fc77 | |||
| c2e569992b | |||
| 4cdcd94061 | |||
| b2727a655f | |||
| 3ae5f28df9 | |||
| 130df3a1d6 | |||
| ffe60c3005 | |||
| f03956e5ae | |||
| 782d543bf7 | |||
| a5d05d3d4e | |||
| 7915feda28 | |||
| 7c8aeffc4a | |||
| 0518f254ed | |||
| ef3adf6eac | |||
| c7e5b56a7d | |||
| 86001ed575 | |||
| b8eadce989 | |||
| 06a7899877 | |||
| 4b8ceddea7 | |||
| d55bb26dc5 | |||
| c93a280d5a | |||
| 0e45af0a50 | |||
| f62063d25c | |||
| 06d1a7fa0b | |||
| 00a6482df6 | |||
| 1d6e7ada97 | |||
| 308f6a05ac | |||
| 4e773f0037 | |||
| 672915aece | |||
| 08affe6664 | |||
| c36201b8b2 | |||
| 347e4a1001 | |||
| c5f436b865 | |||
| 51232f8496 | |||
| a9a875cb5c | |||
| 3ae649453b |
@ -183,6 +183,7 @@ include_patterns = [
|
||||
'benchmarks/instruction_counts/**/*.py',
|
||||
'tools/**/*.py',
|
||||
'torchgen/**/*.py',
|
||||
'torch/utils/pytree/__init__.py',
|
||||
'torch/utils/_pytree.py',
|
||||
'torch/utils/_cxx_pytree.py',
|
||||
'torch/utils/benchmark/utils/common.py',
|
||||
|
||||
@ -195,6 +195,7 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
|
||||
/torch/utils/_pytree.py @XuehaiPan
|
||||
/torch/utils/_cxx_pytree.py @XuehaiPan
|
||||
/torch/utils/pytree/ @XuehaiPan
|
||||
/torch/pytree.py @XuehaiPan
|
||||
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
|
||||
|
||||
# Relating to libtorch ABI
|
||||
|
||||
@ -1,88 +1,78 @@
|
||||
#include <ATen/cuda/CUDAGreenContext.h>
|
||||
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define HAS_CUDA_GREEN_CONTEXT() 1
|
||||
#else
|
||||
#define HAS_CUDA_GREEN_CONTEXT() 0
|
||||
#endif
|
||||
|
||||
namespace at::cuda {
|
||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int driver_version;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||
TORCH_CHECK(
|
||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||
CUcontext pctx = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
TORCH_WARN(
|
||||
"Attempted to create a green context but"
|
||||
" there was no primary context! Creating a primary context...");
|
||||
|
||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
int driver_version;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||
TORCH_CHECK(
|
||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||
CUcontext pctx = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
TORCH_WARN(
|
||||
"Attempted to create a green context but"
|
||||
" there was no primary context! Creating a primary context...");
|
||||
cudaFree(0);
|
||||
}
|
||||
|
||||
cudaFree(0);
|
||||
}
|
||||
CUdevice device;
|
||||
device_id_ = device_id;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||
|
||||
CUdevice device;
|
||||
device_id_ = device_id;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||
// Get device resources
|
||||
CUdevResource device_resource;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||
|
||||
// Get device resources
|
||||
CUdevResource device_resource;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||
// Split resources
|
||||
std::vector<CUdevResource> result(1);
|
||||
auto result_data = result.data();
|
||||
unsigned int nb_groups = 1;
|
||||
CUdevResource remaining;
|
||||
|
||||
// Split resources
|
||||
std::vector<CUdevResource> result(1);
|
||||
auto result_data = result.data();
|
||||
unsigned int nb_groups = 1;
|
||||
CUdevResource remaining;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||
result_data,
|
||||
&nb_groups,
|
||||
&device_resource,
|
||||
&remaining,
|
||||
0, // default flags
|
||||
num_sms));
|
||||
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||
result_data,
|
||||
&nb_groups,
|
||||
&device_resource,
|
||||
&remaining,
|
||||
0, // default flags
|
||||
num_sms));
|
||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||
|
||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||
// Generate resource descriptor
|
||||
CUdevResourceDesc desc;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||
&desc, result_data, 1));
|
||||
|
||||
// Generate resource descriptor
|
||||
CUdevResourceDesc desc;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||
&desc, result_data, 1));
|
||||
// Create green context
|
||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||
|
||||
// Create green context
|
||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||
|
||||
// Convert to regular context
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||
// Convert to regular context
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::unique_ptr<GreenContext> GreenContext::create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id) {
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
if (!device_id.has_value()) {
|
||||
device_id = at::cuda::current_device();
|
||||
}
|
||||
return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms));
|
||||
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
@ -90,7 +80,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
|
||||
// Implement move operations
|
||||
GreenContext::GreenContext(GreenContext&& other) noexcept{
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
@ -101,7 +91,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
}
|
||||
|
||||
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
if (this != &other) {
|
||||
// Clean up current resources
|
||||
if (green_ctx_) {
|
||||
@ -130,7 +120,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
}
|
||||
|
||||
GreenContext::~GreenContext() noexcept{
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||
#else
|
||||
@ -138,9 +128,25 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext GreenContext::getContext() const {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
return context_;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx GreenContext::getGreenContext() const {
|
||||
return green_ctx_;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void GreenContext::setContext() {
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
||||
parent_stream_ = current_stream.stream();
|
||||
|
||||
@ -169,7 +175,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
}
|
||||
|
||||
void GreenContext::popContext() {
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
// see above note about stream being hardcoded to the default stream
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(c10::cuda::getCurrentCUDAStream());
|
||||
|
||||
@ -1,38 +1,53 @@
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
#include <cuda.h>
|
||||
|
||||
// Forward declare green context as opaque ptr
|
||||
typedef struct CUgreenCtx_st* CUgreenCtx;
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define CUDA_HAS_GREEN_CONTEXT 1
|
||||
#else
|
||||
#define CUDA_HAS_GREEN_CONTEXT 0
|
||||
#endif
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
class TORCH_CUDA_CPP_API GreenContext {
|
||||
public:
|
||||
// Green context creation
|
||||
static std::unique_ptr<GreenContext> create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id);
|
||||
~GreenContext() noexcept;
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
|
||||
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
||||
|
||||
// Delete copy constructor and assignment
|
||||
GreenContext(const GreenContext&) = delete;
|
||||
GreenContext& operator=(const GreenContext&) = delete;
|
||||
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
~GreenContext() noexcept;
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext getContext() const;
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx getGreenContext() const;
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void setContext();
|
||||
|
||||
void popContext();
|
||||
|
||||
private:
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int32_t device_id_ = -1;
|
||||
CUgreenCtx green_ctx_ = nullptr;
|
||||
CUcontext context_ = nullptr;
|
||||
cudaStream_t parent_stream_ = nullptr;
|
||||
#endif
|
||||
};
|
||||
} // namespace at::cuda
|
||||
|
||||
@ -50,18 +50,35 @@ static inline bool parseLinearFlatten3d() {
|
||||
// `_flatten_nd_linear` flattens all but the last dimension of the input tensor
|
||||
// before passing it to linear operation
|
||||
static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
|
||||
const auto input_sizes = input.sym_sizes();
|
||||
// can't use -1 in reshape because it errors when a dimension is 0
|
||||
c10::SymInt flattened_dim = 1;
|
||||
for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
|
||||
flattened_dim = flattened_dim * input_sizes[i];
|
||||
const auto input_sizes = input.sym_sizes();
|
||||
|
||||
const auto result_flattened = [&]() -> Tensor {
|
||||
const auto input_ncols = input_sizes.back();
|
||||
const auto input_flattened_nrows = [&]() -> c10::SymInt {
|
||||
// can't use -1 in reshape because it errors when a dimension is 0
|
||||
auto flattened_nrows = c10::SymInt{1};
|
||||
for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) {
|
||||
flattened_nrows *= size;
|
||||
}
|
||||
return flattened_nrows;
|
||||
}();
|
||||
|
||||
const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols});
|
||||
if (weight.layout() == c10::kStrided) {
|
||||
return at::addmm(bias, input_flattened, weight.t());
|
||||
} else {
|
||||
// weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
|
||||
// so we transpose the problem.
|
||||
// NOTE: at::matmul handles (dense @ sparse) similarly.
|
||||
const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1);
|
||||
return at::addmm(bias_t, weight, input_flattened.t()).t();
|
||||
}
|
||||
auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)});
|
||||
const auto result = at::addmm(bias, inp_reshape, weight.t());
|
||||
auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
|
||||
c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
|
||||
sizes_vec.push_back(result.sym_size(1));
|
||||
return result.view_symint(sizes_vec);
|
||||
}();
|
||||
|
||||
// Unflatten flattened row dims
|
||||
auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()};
|
||||
result_sizes.back() = result_flattened.sym_size(1);
|
||||
return result_flattened.view_symint(result_sizes);
|
||||
}
|
||||
|
||||
|
||||
@ -90,15 +107,23 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
|
||||
// Fused op is marginally faster.
|
||||
return at::addmm(*bias, input, weight.t());
|
||||
}
|
||||
if (bias->defined() && !input.is_xla()) {
|
||||
// Also hit the fused path for contiguous 3D input, if not using xla
|
||||
|
||||
const auto is_bias_likely_fusable = (
|
||||
bias->defined() &&
|
||||
// cuBLASLt: will fuse in the epilogue without copies
|
||||
// when input/weight/bias are all strided.
|
||||
// When weight is not strided, bias will not be fused,
|
||||
// but we can still dispatch here to avoid at::matmul
|
||||
// path which will probably use a very similar
|
||||
// flattening optimization.
|
||||
(bias->dim() == 1 && bias->is_contiguous_or_false())
|
||||
);
|
||||
if (is_bias_likely_fusable && !input.is_xla()) {
|
||||
// Also hit the fused path for contiguous nD input, if not using xla
|
||||
// backend. Reshaping/flattening has some performance implications on xla.
|
||||
bool is_contiguous = input.is_contiguous_or_false();
|
||||
if (is_contiguous && input_dim == 3) {
|
||||
if (input.is_contiguous_or_false()) {
|
||||
return _flatten_nd_linear(input, weight, *bias);
|
||||
} else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
|
||||
return _flatten_nd_linear(input, weight, *bias);
|
||||
} else if (parseLinearFlatten3d() && input_dim == 3) {
|
||||
} else if (parseLinearFlatten3d()) {
|
||||
// If user forces flattening via env var
|
||||
const Tensor input_cont = input.contiguous();
|
||||
return _flatten_nd_linear(input_cont, weight, *bias);
|
||||
|
||||
@ -59,6 +59,7 @@ torch.special <special>
|
||||
torch.overrides
|
||||
torch.nativert <nativert>
|
||||
torch.package <package>
|
||||
torch.pytree <pytree>
|
||||
profiler
|
||||
nn.init
|
||||
nn.attention
|
||||
@ -76,6 +77,7 @@ sparse
|
||||
storage
|
||||
torch.testing <testing>
|
||||
torch.utils <utils>
|
||||
torch.utils.pytree
|
||||
torch.utils.benchmark <benchmark_utils>
|
||||
torch.utils.checkpoint <checkpoint>
|
||||
torch.utils.cpp_extension <cpp_extension>
|
||||
|
||||
7
docs/source/pytree.rst
Normal file
7
docs/source/pytree.rst
Normal file
@ -0,0 +1,7 @@
|
||||
torch.pytree
|
||||
============
|
||||
|
||||
.. currentmodule:: torch.pytree
|
||||
|
||||
.. automodule:: torch.pytree
|
||||
:members:
|
||||
7
docs/source/torch.utils.pytree.rst
Normal file
7
docs/source/torch.utils.pytree.rst
Normal file
@ -0,0 +1,7 @@
|
||||
torch.utils.pytree
|
||||
==================
|
||||
|
||||
.. currentmodule:: torch.utils.pytree
|
||||
|
||||
.. automodule:: torch.utils.pytree
|
||||
:members:
|
||||
@ -29,6 +29,7 @@ files =
|
||||
benchmarks/instruction_counts,
|
||||
tools,
|
||||
torch/profiler/_memory_profiler.py,
|
||||
torch/utils/pytree/__init__.py,
|
||||
torch/utils/_pytree.py,
|
||||
torch/utils/_cxx_pytree.py,
|
||||
torch/utils/benchmark/utils/common.py,
|
||||
|
||||
@ -687,6 +687,28 @@
|
||||
"kineto_available",
|
||||
"record_function"
|
||||
],
|
||||
"torch.pytree": [
|
||||
"PyTreeSpec",
|
||||
"register_node",
|
||||
"all",
|
||||
"all_only",
|
||||
"any",
|
||||
"any_only",
|
||||
"flatten",
|
||||
"iter",
|
||||
"leaves",
|
||||
"map",
|
||||
"map_",
|
||||
"map_only",
|
||||
"map_only_",
|
||||
"structure",
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance"
|
||||
],
|
||||
"torch.quantization": [
|
||||
"ABC",
|
||||
"DeQuantStub",
|
||||
|
||||
@ -147,8 +147,8 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
t: "f32[10]" = l_x_ + l_y_
|
||||
|
||||
trace_point_tensor_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_spec
|
||||
trace_point_tensor_input_spec : torch.utils._pytree.TreeSpec = self.trace_point_tensor_input_spec
|
||||
trace_point_tensor_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_spec
|
||||
trace_point_tensor_input_spec : torch.utils.pytree.PyTreeSpec = self.trace_point_tensor_input_spec
|
||||
res: "f32[10]" = torch.ops.higher_order.flat_apply(trace_point_tensor_spec, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_spec = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None
|
||||
return (res,)
|
||||
""", # NOQA: B950
|
||||
|
||||
@ -40,6 +40,7 @@ import torch._inductor.test_case
|
||||
import torch.onnx.operators
|
||||
import torch.utils._pytree as python_pytree
|
||||
import torch.utils.cpp_extension
|
||||
import torch.utils.pytree as generic_pytree
|
||||
from torch import Tensor
|
||||
from torch._C import FileCheck
|
||||
from torch._dynamo import allow_in_graph
|
||||
@ -104,6 +105,7 @@ from torch.testing._internal.logging_utils import logs_to_string
|
||||
|
||||
|
||||
pytree_modules = {
|
||||
"generic": generic_pytree,
|
||||
"python": python_pytree,
|
||||
}
|
||||
if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
|
||||
@ -8245,7 +8245,7 @@ graph():
|
||||
%to : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
|
||||
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%to,), kwargs = {})
|
||||
%_spec_1 : [num_users=1] = get_attr[target=_spec_1]
|
||||
%tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
|
||||
%tree_unflatten : [num_users=1] = call_function[target=torch.utils.pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
|
||||
return tree_unflatten""",
|
||||
)
|
||||
|
||||
|
||||
@ -249,11 +249,11 @@ def forward(self, x, y):
|
||||
_spec_0 = self._spec_0
|
||||
_spec_1 = self._spec_1
|
||||
_spec_4 = self._spec_4
|
||||
tree_flatten = torch.utils._pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
|
||||
tree_flatten = torch.utils.pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
|
||||
getitem = tree_flatten[0]; tree_flatten = None
|
||||
x = getitem[0]
|
||||
y = getitem[1]; getitem = None
|
||||
tree_unflatten_1 = torch.utils._pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
|
||||
tree_unflatten_1 = torch.utils.pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
|
||||
getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None
|
||||
getitem_2 = getitem_1[0]
|
||||
getitem_3 = getitem_1[1]; getitem_1 = None
|
||||
@ -261,7 +261,7 @@ def forward(self, x, y):
|
||||
bar = self.bar(foo); foo = None
|
||||
tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None
|
||||
getitem_10 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
|
||||
tree_unflatten = torch.utils._pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
|
||||
tree_unflatten = torch.utils.pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
|
||||
return tree_unflatten""",
|
||||
)
|
||||
|
||||
|
||||
@ -6474,11 +6474,7 @@ class CommonTemplate:
|
||||
# Constant folding was explicitly turned off due to issue #108388
|
||||
# Turn it back on for test
|
||||
@unittest.skipIf(config.triton.native_matmul, "native matmul has better precision")
|
||||
@torch._inductor.config.patch(
|
||||
joint_graph_constant_folding=True,
|
||||
# Numerical accuracy failure for triton fp16
|
||||
max_autotune_gemm_backends="ATEN",
|
||||
)
|
||||
@torch._inductor.config.patch(joint_graph_constant_folding=True)
|
||||
def test_remove_no_ops(self):
|
||||
def matmul_with_op(x, y, fn):
|
||||
return fn(x @ y)
|
||||
@ -6906,11 +6902,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
_, (code0, code1) = _run_and_get_stripped_kernels(b, x)
|
||||
self.assertEqual(code0, code1)
|
||||
|
||||
@config.patch(
|
||||
force_disable_caches=True,
|
||||
# Test expects a single (fused) kernel to be generated
|
||||
max_autotune_gemm_backends="ATEN",
|
||||
)
|
||||
@config.patch(force_disable_caches=True)
|
||||
@skip_if_cpp_wrapper("run_and_get_kernels issue")
|
||||
@unittest.skipIf(config.triton.native_matmul, "matmul is now generated")
|
||||
def test_deterministic_codegen_with_suffix(self):
|
||||
@ -14123,8 +14115,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
code_disallowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_disallowed)
|
||||
return code_allowed != code_disallowed
|
||||
|
||||
# If matmul is implemented by triton there is more reuse
|
||||
@config.patch(max_autotune_gemm_backends="ATEN")
|
||||
@unittest.skipIf(config.triton.native_matmul, "matmul is now generated")
|
||||
def test_allow_reuse_disable_if_exceed_peak(self):
|
||||
@torch.compile
|
||||
|
||||
@ -4222,6 +4222,7 @@ class TestCudaMallocAsync(TestCase):
|
||||
ss = torch.cuda.memory._snapshot()
|
||||
|
||||
trace_plot(ss)
|
||||
trace_plot(ss, filter_freed=True)
|
||||
segment_plot(ss)
|
||||
text = json.dumps(ss)
|
||||
|
||||
|
||||
@ -820,12 +820,13 @@ class TestPythonPytree(TestCase):
|
||||
script = """
|
||||
import sys
|
||||
import torch
|
||||
import torch.utils._pytree
|
||||
assert "torch.utils.pytree" in sys.modules
|
||||
assert "torch.utils._pytree" in sys.modules
|
||||
if "torch.utils._cxx_pytree" in sys.modules:
|
||||
raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
|
||||
if "optree" in sys.modules:
|
||||
raise RuntimeError("importing torch.utils._pytree should not import optree")
|
||||
if not torch.utils.pytree.PYTORCH_USE_CXX_PYTREE:
|
||||
if "torch.utils._cxx_pytree" in sys.modules:
|
||||
raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
|
||||
if "optree" in sys.modules:
|
||||
raise RuntimeError("importing torch.utils._pytree should not import optree")
|
||||
"""
|
||||
try:
|
||||
subprocess.check_output(
|
||||
|
||||
@ -2769,6 +2769,7 @@ if TYPE_CHECKING:
|
||||
_inductor as _inductor,
|
||||
_subclasses as _subclasses,
|
||||
onnx as onnx,
|
||||
pytree as pytree,
|
||||
)
|
||||
|
||||
else:
|
||||
@ -2778,6 +2779,7 @@ else:
|
||||
"_export",
|
||||
# ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
|
||||
"onnx",
|
||||
"pytree",
|
||||
}
|
||||
|
||||
def __getattr__(name):
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
"""
|
||||
Python polyfills for torch.utils.pytree
|
||||
"""
|
||||
@ -7,7 +9,6 @@ from __future__ import annotations
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
|
||||
@ -18,7 +19,7 @@ from ..decorators import substitute_in_graph
|
||||
if TYPE_CHECKING:
|
||||
import builtins
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from typing_extensions import Self
|
||||
from typing_extensions import Self, TypeIs
|
||||
|
||||
|
||||
__all__: list[str] = []
|
||||
@ -346,8 +347,10 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
assert callable(self._unflatten_func)
|
||||
return self._unflatten_func(self._metadata, subtrees)
|
||||
|
||||
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
|
||||
return isinstance(obj, PyTreeSpec)
|
||||
def _is_pytreespec_instance(
|
||||
obj: Any, /
|
||||
) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]:
|
||||
return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec))
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree.treespec_leaf,
|
||||
@ -549,10 +552,13 @@ if python_pytree._cxx_pytree_dynamo_traceable:
|
||||
)
|
||||
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
raise TypeError(
|
||||
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
if not _is_pytreespec_instance(leaves):
|
||||
raise TypeError(
|
||||
f"Expected `treespec` to be an instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
# Allow passing the PyTreeSpec instance as the first argument
|
||||
leaves, treespec = treespec, leaves
|
||||
return treespec.unflatten(leaves)
|
||||
|
||||
__all__ += ["tree_unflatten"]
|
||||
|
||||
@ -3450,6 +3450,7 @@ MOD_INLINELIST = [
|
||||
"torch.utils._python_dispatch",
|
||||
"torch.utils._pytree",
|
||||
"torch.utils.hooks",
|
||||
"torch.utils.pytree",
|
||||
]
|
||||
assert sorted(set(MOD_INLINELIST)) == MOD_INLINELIST
|
||||
MOD_INLINELIST = set(MOD_INLINELIST)
|
||||
|
||||
@ -39,13 +39,13 @@ from torch.utils._pytree import (
|
||||
_register_pytree_node,
|
||||
Context,
|
||||
FlattenFunc,
|
||||
FromDumpableContextFn,
|
||||
FromDumpableContextFunc,
|
||||
GetAttrKey,
|
||||
KeyPath,
|
||||
keystr,
|
||||
MappingKey,
|
||||
SequenceKey,
|
||||
ToDumpableContextFn,
|
||||
ToDumpableContextFunc,
|
||||
tree_flatten_with_path,
|
||||
UnflattenFunc,
|
||||
)
|
||||
@ -485,8 +485,8 @@ def register_dataclass_as_pytree_node(
|
||||
unflatten_fn: Optional[UnflattenFunc] = None,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
|
||||
return_none_fields: bool = False,
|
||||
) -> None:
|
||||
assert dataclasses.is_dataclass(cls), (
|
||||
|
||||
@ -446,7 +446,43 @@ def _format_viz(data, viz_kind, device):
|
||||
)
|
||||
|
||||
|
||||
def trace_plot(data, device=None, plot_segments=False):
|
||||
def filter_alloc_free_pairs(data):
|
||||
for dev_id in range(len(data["device_traces"])):
|
||||
# set of indexes of trace events for alloc-free pairs
|
||||
filterSet = set()
|
||||
# map from addr to index of alloc event
|
||||
allocMap = {}
|
||||
# set of addrs from free_requested events
|
||||
freeRequested = set()
|
||||
for idx, event in enumerate(data["device_traces"][dev_id]):
|
||||
if event["action"] == "alloc":
|
||||
allocMap[event["addr"]] = idx
|
||||
elif event["action"] == "free_requested":
|
||||
freeRequested.add(event["addr"])
|
||||
if allocMap.get(event["addr"]) is not None:
|
||||
filterSet.add(idx)
|
||||
filterSet.add(allocMap[event["addr"]])
|
||||
allocMap.pop(event["addr"])
|
||||
elif event["action"] == "free_completed":
|
||||
if event["addr"] in freeRequested:
|
||||
freeRequested.remove(event["addr"])
|
||||
filterSet.add(idx)
|
||||
else:
|
||||
print(f"free_completed without free_requested: {event}")
|
||||
|
||||
# Remove events whose index is in filterSet
|
||||
if filterSet:
|
||||
# Create a new list excluding events with indices in filterSet
|
||||
data["device_traces"][dev_id] = [
|
||||
event
|
||||
for idx, event in enumerate(data["device_traces"][dev_id])
|
||||
if idx not in filterSet
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def trace_plot(data, device=None, plot_segments=False, filter_freed=False):
|
||||
"""Generate a visualization over time of the memory usage recorded by the trace as an html file.
|
||||
|
||||
Args:
|
||||
@ -454,10 +490,15 @@ def trace_plot(data, device=None, plot_segments=False):
|
||||
device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
|
||||
plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
|
||||
Defaults to False.
|
||||
filter_freed (bool, optional): Filter out alloc-free paired events to only plot allocations that are not freed yet.
|
||||
Defaults to False to plot all trace events.
|
||||
|
||||
Returns:
|
||||
str: HTML of visualization
|
||||
"""
|
||||
if filter_freed:
|
||||
data = filter_alloc_free_pairs(data)
|
||||
|
||||
return _format_viz(
|
||||
data,
|
||||
"Active Memory Timeline"
|
||||
@ -698,6 +739,14 @@ if __name__ == "__main__":
|
||||
"-s", "--segments", action="store_true", help=help
|
||||
)
|
||||
|
||||
help = (
|
||||
"filter out allocation-free pairs to only visualize the allocations that are not freed yet;"
|
||||
"useful to reduce the number of events for large traces for debugging OOM"
|
||||
)
|
||||
trace_plot_a.add_argument(
|
||||
"-f", "--filter_freed", action="store_true", help=help
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def _read(name):
|
||||
@ -734,7 +783,12 @@ if __name__ == "__main__":
|
||||
data = _read(args.input)
|
||||
_write(
|
||||
args.output,
|
||||
trace_plot(data, device=args.device, plot_segments=args.segments),
|
||||
trace_plot(
|
||||
data,
|
||||
device=args.device,
|
||||
plot_segments=args.segments,
|
||||
filter_freed=args.filter_freed,
|
||||
),
|
||||
)
|
||||
elif args.action == "segment_plot":
|
||||
data = _read(args.input)
|
||||
|
||||
@ -42,7 +42,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
|
||||
%foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
|
||||
%tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {})
|
||||
%getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
|
||||
%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
|
||||
%tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils.pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
|
||||
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
|
||||
%getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {})
|
||||
%getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})
|
||||
@ -293,7 +293,7 @@ def _swap_module_helper(
|
||||
%y : [num_users=1] = placeholder[target=y]
|
||||
|
||||
%_spec_0 : [num_users=1] = get_attr[target=_spec_0]
|
||||
%tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
|
||||
%tree_unflatten : [num_users=2] = call_function[target=torch.utils.pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {})
|
||||
%getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {})
|
||||
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {})
|
||||
%getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {})
|
||||
|
||||
108
torch/pytree.py
Normal file
108
torch/pytree.py
Normal file
@ -0,0 +1,108 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
"""
|
||||
Contains utility functions for working with nested python data structures.
|
||||
|
||||
A *pytree* is Python nested data structure. It is a tree in the sense that
|
||||
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
|
||||
Python values. Furthermore, a pytree should not contain reference cycles.
|
||||
|
||||
pytrees are useful for working with nested collections of Tensors. For example,
|
||||
one can use `map` to map a function over all Tensors inside some nested
|
||||
collection of Tensors and `leaves` to get a flat list of all Tensors
|
||||
inside some nested collection. pytrees are helpful for implementing nested
|
||||
collection support for PyTorch APIs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any as _Any, TYPE_CHECKING as _TYPE_CHECKING
|
||||
|
||||
from torch.utils.pytree import (
|
||||
is_namedtuple,
|
||||
is_namedtuple_class,
|
||||
is_namedtuple_instance,
|
||||
is_structseq,
|
||||
is_structseq_class,
|
||||
is_structseq_instance,
|
||||
PyTree,
|
||||
PyTreeSpec,
|
||||
register_pytree_node as register_node,
|
||||
tree_all as all,
|
||||
tree_all_only as all_only,
|
||||
tree_any as any,
|
||||
tree_any_only as any_only,
|
||||
tree_flatten as flatten,
|
||||
tree_iter as iter,
|
||||
tree_leaves as leaves,
|
||||
tree_map as map,
|
||||
tree_map_ as map_,
|
||||
tree_map_only as map_only,
|
||||
tree_map_only_ as map_only_,
|
||||
tree_structure as structure,
|
||||
tree_unflatten as _tree_unflatten,
|
||||
)
|
||||
|
||||
|
||||
if _TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PyTreeSpec",
|
||||
"register_node",
|
||||
"flatten",
|
||||
"unflatten",
|
||||
"iter",
|
||||
"leaves",
|
||||
"structure",
|
||||
"map",
|
||||
"map_",
|
||||
"map_only",
|
||||
"map_only_",
|
||||
"all",
|
||||
"any",
|
||||
"all_only",
|
||||
"any_only",
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
]
|
||||
|
||||
|
||||
def unflatten(treespec: PyTreeSpec, leaves: Iterable[_Any]) -> PyTree:
|
||||
"""Reconstruct a pytree from the treespec and the leaves.
|
||||
|
||||
The inverse of :func:`flatten`.
|
||||
|
||||
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
|
||||
>>> leaves, treespec = torch.pytree.flatten(tree)
|
||||
>>> tree == torch.pytree.unflatten(treespec, leaves)
|
||||
True
|
||||
|
||||
.. warning::
|
||||
|
||||
This function has a different signature than :func:`torch.utils.pytree.tree_unflatten`.
|
||||
The ``treespec`` argument comes first to have a better :class:`functools.partial` support:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import functools
|
||||
|
||||
unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
|
||||
tree1 = unflatten_fn(leaves1)
|
||||
tree2 = unflatten_fn(leaves2)
|
||||
|
||||
Args:
|
||||
treespec (PyTreeSpec): The treespec to reconstruct.
|
||||
leaves (iterable): The list of leaves to use for reconstruction. The list must match the
|
||||
number of leaves of the treespec.
|
||||
|
||||
Returns:
|
||||
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
|
||||
``treespec``.
|
||||
"""
|
||||
return _tree_unflatten(leaves, treespec)
|
||||
@ -11,6 +11,7 @@ from torch.utils import (
|
||||
data as data,
|
||||
deterministic as deterministic,
|
||||
hooks as hooks,
|
||||
pytree as pytree,
|
||||
)
|
||||
from torch.utils.backend_registration import (
|
||||
generate_methods_for_privateuse1_backend,
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
"""
|
||||
Contains utility functions for working with nested python data structures.
|
||||
|
||||
@ -21,13 +23,21 @@ from typing_extensions import deprecated, Self, TypeAlias, TypeIs
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch.torch_version import TorchVersion as _TorchVersion
|
||||
from torch.utils._pytree import (
|
||||
Context,
|
||||
DumpableContext,
|
||||
FlattenFunc,
|
||||
FlattenWithKeysFunc,
|
||||
FromDumpableContextFunc,
|
||||
is_namedtuple,
|
||||
is_namedtuple_class,
|
||||
is_namedtuple_instance,
|
||||
is_structseq,
|
||||
is_structseq_class,
|
||||
is_structseq_instance,
|
||||
KeyEntry,
|
||||
KeyPath,
|
||||
PyTree,
|
||||
ToDumpableContextFunc,
|
||||
UnflattenFunc,
|
||||
)
|
||||
|
||||
|
||||
@ -51,8 +61,9 @@ __all__ = [
|
||||
"FlattenFunc",
|
||||
"UnflattenFunc",
|
||||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"ToDumpableContextFunc",
|
||||
"FromDumpableContextFunc",
|
||||
"PyTreeSpec",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"keystr",
|
||||
@ -99,19 +110,8 @@ S = TypeVar("S")
|
||||
U = TypeVar("U")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
TreeSpec: TypeAlias = PyTreeSpec
|
||||
|
||||
Context = Any
|
||||
PyTree = Any
|
||||
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
|
||||
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
|
||||
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
KeyPath = tuple[KeyEntry, ...]
|
||||
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
|
||||
OpTreeUnflattenFunc: TypeAlias = Callable[[Context, Iterable[Any]], PyTree]
|
||||
|
||||
|
||||
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||
@ -128,8 +128,8 @@ def register_pytree_node(
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
|
||||
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node.
|
||||
@ -196,8 +196,8 @@ def _register_pytree_node(
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node for the C++ pytree only.
|
||||
|
||||
@ -247,8 +247,8 @@ def _private_register_pytree_node(
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
|
||||
) -> None:
|
||||
"""This is an internal function that is used to register a pytree node type
|
||||
for the C++ pytree only. End-users should use :func:`register_pytree_node`
|
||||
@ -265,8 +265,10 @@ def _private_register_pytree_node(
|
||||
)
|
||||
|
||||
|
||||
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
|
||||
return isinstance(obj, TreeSpec)
|
||||
def _is_pytreespec_instance(
|
||||
obj: Any, /
|
||||
) -> TypeIs[Union[TreeSpec, python_pytree.TreeSpec]]:
|
||||
return isinstance(obj, (TreeSpec, python_pytree.TreeSpec))
|
||||
|
||||
|
||||
def treespec_leaf() -> TreeSpec:
|
||||
@ -393,7 +395,15 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
|
||||
``treespec``.
|
||||
"""
|
||||
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
if not _is_pytreespec_instance(leaves):
|
||||
raise TypeError(
|
||||
f"Expected `treespec` to be an instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
# Allow passing the PyTreeSpec instance as the first argument
|
||||
leaves, treespec = treespec, leaves
|
||||
return treespec.unflatten(leaves)
|
||||
|
||||
|
||||
def tree_iter(
|
||||
@ -972,7 +982,7 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
|
||||
"""Serialize a treespec to a JSON string."""
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
raise TypeError(
|
||||
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
|
||||
f"Expected `treespec` to be instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
|
||||
@ -993,16 +1003,22 @@ def treespec_loads(serialized: str) -> TreeSpec:
|
||||
return treespec
|
||||
|
||||
|
||||
class _DummyLeaf:
|
||||
class _Asterisk(str):
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls) -> Self:
|
||||
return super().__new__(cls, "*")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "*"
|
||||
return "*" # no quotes
|
||||
|
||||
|
||||
_asterisk = _Asterisk()
|
||||
del _Asterisk
|
||||
|
||||
|
||||
def treespec_pprint(treespec: TreeSpec) -> str:
|
||||
dummy_tree = tree_unflatten(
|
||||
[_DummyLeaf() for _ in range(treespec.num_leaves)],
|
||||
treespec,
|
||||
)
|
||||
dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec)
|
||||
return repr(dummy_tree)
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
"""
|
||||
Contains utility functions for working with nested python data structures.
|
||||
|
||||
@ -20,6 +22,7 @@ import functools
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
import warnings
|
||||
@ -36,10 +39,11 @@ from typing import (
|
||||
Optional,
|
||||
overload,
|
||||
Protocol,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import deprecated, NamedTuple, Self
|
||||
from typing_extensions import deprecated, NamedTuple, Self, TypeAlias, TypeIs
|
||||
|
||||
from torch.torch_version import TorchVersion as _TorchVersion
|
||||
|
||||
@ -50,8 +54,9 @@ __all__ = [
|
||||
"FlattenFunc",
|
||||
"UnflattenFunc",
|
||||
"DumpableContext",
|
||||
"ToDumpableContextFn",
|
||||
"FromDumpableContextFn",
|
||||
"ToDumpableContextFunc",
|
||||
"FromDumpableContextFunc",
|
||||
"PyTreeSpec",
|
||||
"TreeSpec",
|
||||
"LeafSpec",
|
||||
"keystr",
|
||||
@ -117,17 +122,21 @@ class EnumEncoder(json.JSONEncoder):
|
||||
return cast(str, super().default(obj))
|
||||
|
||||
|
||||
Context = Any
|
||||
PyTree = Any
|
||||
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
|
||||
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
ToStrFunc = Callable[["TreeSpec", list[str]], str]
|
||||
MaybeFromStrFunc = Callable[[str], Optional[tuple[Any, Context, str]]]
|
||||
KeyPath = tuple[KeyEntry, ...]
|
||||
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
|
||||
Context: TypeAlias = Any
|
||||
PyTree: TypeAlias = Any
|
||||
FlattenFunc: TypeAlias = Callable[[PyTree], tuple[list[Any], Context]]
|
||||
UnflattenFunc: TypeAlias = Callable[[Iterable[Any], Context], PyTree]
|
||||
DumpableContext: TypeAlias = Any # Any json dumpable text
|
||||
ToDumpableContextFunc: TypeAlias = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFunc: TypeAlias = Callable[[DumpableContext], Context]
|
||||
ToDumpableContextFn: TypeAlias = ToDumpableContextFunc
|
||||
FromDumpableContextFn: TypeAlias = FromDumpableContextFunc
|
||||
ToStrFunc: TypeAlias = Callable[["TreeSpec", list[str]], str]
|
||||
MaybeFromStrFunc: TypeAlias = Callable[[str], Optional[tuple[Any, Context, str]]]
|
||||
KeyPath: TypeAlias = tuple[KeyEntry, ...]
|
||||
FlattenWithKeysFunc: TypeAlias = Callable[
|
||||
[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]
|
||||
]
|
||||
|
||||
|
||||
# A NodeDef holds two callables:
|
||||
@ -160,8 +169,8 @@ SUPPORTED_NODES: dict[type[Any], NodeDef] = {}
|
||||
class _SerializeNodeDef(NamedTuple):
|
||||
typ: type[Any]
|
||||
serialized_type_name: str
|
||||
to_dumpable_context: Optional[ToDumpableContextFn]
|
||||
from_dumpable_context: Optional[FromDumpableContextFn]
|
||||
to_dumpable_context: Optional[ToDumpableContextFunc]
|
||||
from_dumpable_context: Optional[FromDumpableContextFunc]
|
||||
|
||||
|
||||
SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {}
|
||||
@ -198,8 +207,8 @@ def register_pytree_node(
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
|
||||
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node.
|
||||
@ -522,8 +531,8 @@ def _register_pytree_node(
|
||||
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
|
||||
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
|
||||
) -> None:
|
||||
"""Register a container-like type as pytree node for the Python pytree only.
|
||||
@ -589,8 +598,8 @@ def _private_register_pytree_node(
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFunc] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFunc] = None,
|
||||
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
|
||||
) -> None:
|
||||
"""This is an internal function that is used to register a pytree node type
|
||||
@ -1072,7 +1081,9 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -
|
||||
# children_specs: specs for each child of the root Node
|
||||
# num_leaves: the number of leaves
|
||||
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
|
||||
class TreeSpec:
|
||||
class PyTreeSpec:
|
||||
"""Representing the structure of the pytree."""
|
||||
|
||||
type: Any
|
||||
_context: Context
|
||||
_children: list[Self]
|
||||
@ -1140,6 +1151,7 @@ class TreeSpec:
|
||||
return self._children
|
||||
|
||||
def is_leaf(self) -> bool:
|
||||
"""Test whether the treespec represents a leaf."""
|
||||
return self.num_nodes == 1 and self.num_leaves == 1
|
||||
|
||||
def children(self) -> list[Self]:
|
||||
@ -1149,12 +1161,14 @@ class TreeSpec:
|
||||
return self._children[index]
|
||||
|
||||
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
|
||||
def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
|
||||
"""Flatten the subtrees in ``tree`` up to the structure of this treespec and return a list of subtrees."""
|
||||
|
||||
def helper(treespec: TreeSpec, node: PyTree, subtrees: list[PyTree]) -> None:
|
||||
if treespec.is_leaf():
|
||||
subtrees.append(tree)
|
||||
subtrees.append(node)
|
||||
return
|
||||
|
||||
node_type = _get_node_type(tree)
|
||||
node_type = _get_node_type(node)
|
||||
if treespec.type not in BUILTIN_TYPES:
|
||||
# Always require custom node types to match exactly
|
||||
if node_type != treespec.type:
|
||||
@ -1163,7 +1177,7 @@ class TreeSpec:
|
||||
f"expected {treespec.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
children, context = flatten_fn(tree)
|
||||
children, context = flatten_fn(node)
|
||||
if len(children) != treespec.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
@ -1185,10 +1199,10 @@ class TreeSpec:
|
||||
f"Node type mismatch; "
|
||||
f"expected {treespec.type!r}, but got {node_type!r}.",
|
||||
)
|
||||
if len(tree) != treespec.num_children:
|
||||
if len(node) != treespec.num_children:
|
||||
raise ValueError(
|
||||
f"Node arity mismatch; "
|
||||
f"expected {treespec.num_children}, but got {len(tree)}.",
|
||||
f"expected {treespec.num_children}, but got {len(node)}.",
|
||||
)
|
||||
|
||||
if both_standard_dict:
|
||||
@ -1200,7 +1214,7 @@ class TreeSpec:
|
||||
else treespec._context[1]
|
||||
)
|
||||
expected_keys = dict_context
|
||||
got_key_set = set(tree)
|
||||
got_key_set = set(node)
|
||||
expected_key_set = set(expected_keys)
|
||||
if got_key_set != expected_key_set:
|
||||
missing_keys = expected_key_set.difference(got_key_set)
|
||||
@ -1211,11 +1225,11 @@ class TreeSpec:
|
||||
if extra_keys:
|
||||
message += f"; extra key(s): {extra_keys}"
|
||||
raise ValueError(f"Node keys mismatch{message}.")
|
||||
children = [tree[key] for key in expected_keys]
|
||||
children = [node[key] for key in expected_keys]
|
||||
else:
|
||||
# node_type is treespec.type
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
children, context = flatten_fn(tree)
|
||||
children, context = flatten_fn(node)
|
||||
if (
|
||||
node_type is not deque # ignore mismatch of `maxlen` for deque
|
||||
) and context != treespec._context:
|
||||
@ -1232,6 +1246,7 @@ class TreeSpec:
|
||||
return subtrees
|
||||
|
||||
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
|
||||
"""Reconstruct a pytree from the leaves."""
|
||||
if not isinstance(leaves, (list, tuple)):
|
||||
leaves = list(leaves)
|
||||
if len(leaves) != self.num_leaves:
|
||||
@ -1277,6 +1292,9 @@ class TreeSpec:
|
||||
return hash((node_type, hashable_context, tuple(self._children)))
|
||||
|
||||
|
||||
TreeSpec: TypeAlias = PyTreeSpec
|
||||
|
||||
|
||||
# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about
|
||||
# this class with `dataclasses.fields`, etc., while having a simplified
|
||||
# constructor that takes no argument, we wrap with `dataclass(init=True, ...)`
|
||||
@ -1336,6 +1354,39 @@ def treespec_dict(
|
||||
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch.utils._cxx_pytree as cxx
|
||||
|
||||
|
||||
def _is_pytreespec_instance(obj: Any) -> TypeIs[Union[TreeSpec, "cxx.TreeSpec"]]:
|
||||
if isinstance(obj, TreeSpec):
|
||||
return True
|
||||
if "torch.utils._cxx_pytree" in sys.modules:
|
||||
# The C++ pytree module is not always available, so we check if it is loaded.
|
||||
# If the C++ pytree module is loaded, we can check if the treespec
|
||||
# is an instance of the C++ TreeSpec class.
|
||||
from torch.utils._cxx_pytree import TreeSpec as CxxTreeSpec
|
||||
|
||||
if isinstance(obj, CxxTreeSpec):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _ensure_python_treespec_instance(
|
||||
treespec: Union[TreeSpec, "cxx.TreeSpec"],
|
||||
) -> TreeSpec:
|
||||
if isinstance(treespec, TreeSpec):
|
||||
return treespec
|
||||
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
raise TypeError(
|
||||
f"Expected `treespec` to be an instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
dummy_tree = treespec.unflatten([0] * treespec.num_leaves)
|
||||
return tree_structure(dummy_tree)
|
||||
|
||||
|
||||
def tree_flatten(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
@ -1366,11 +1417,14 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
|
||||
"""Given a list of values and a TreeSpec, builds a pytree.
|
||||
This is the inverse operation of `tree_flatten`.
|
||||
"""
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
raise TypeError(
|
||||
f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
|
||||
f"instance of TreeSpec but got item of type {type(treespec)}.",
|
||||
)
|
||||
if not _is_pytreespec_instance(treespec):
|
||||
if not _is_pytreespec_instance(leaves):
|
||||
raise TypeError(
|
||||
f"Expected `treespec` to be an instance of "
|
||||
f"PyTreeSpec but got item of type {type(treespec)}."
|
||||
)
|
||||
# Allow passing the PyTreeSpec instance as the first argument
|
||||
leaves, treespec = treespec, leaves
|
||||
return treespec.unflatten(leaves)
|
||||
|
||||
|
||||
@ -1800,34 +1854,30 @@ def _broadcast_to_and_flatten(
|
||||
treespec: TreeSpec,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Optional[list[Any]]:
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
raise AssertionError("treespec must be a TreeSpec")
|
||||
def broadcast_prefix(
|
||||
prefix_tree: PyTree,
|
||||
full_tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> list[Any]:
|
||||
result: list[Any] = []
|
||||
|
||||
if tree_is_leaf(tree, is_leaf=is_leaf):
|
||||
return [tree] * treespec.num_leaves
|
||||
if treespec.is_leaf():
|
||||
def add_leaves(x: Any, subtree: PyTree) -> None:
|
||||
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
|
||||
result.extend([x] * subtreespec.num_leaves)
|
||||
|
||||
tree_map_(
|
||||
add_leaves,
|
||||
prefix_tree,
|
||||
full_tree,
|
||||
is_leaf=is_leaf,
|
||||
)
|
||||
return result
|
||||
|
||||
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
|
||||
try:
|
||||
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
|
||||
except ValueError:
|
||||
return None
|
||||
node_type = _get_node_type(tree)
|
||||
if node_type != treespec.type:
|
||||
return None
|
||||
|
||||
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
|
||||
child_pytrees, context = flatten_fn(tree)
|
||||
|
||||
# Check if the Node is different from the spec
|
||||
if len(child_pytrees) != treespec.num_children or context != treespec._context:
|
||||
return None
|
||||
|
||||
# Recursively flatten the children
|
||||
result: list[Any] = []
|
||||
for child, child_spec in zip(child_pytrees, treespec._children):
|
||||
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
|
||||
if flat is not None:
|
||||
result += flat
|
||||
else:
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -1941,11 +1991,7 @@ _SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)
|
||||
|
||||
|
||||
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
|
||||
if not isinstance(treespec, TreeSpec):
|
||||
raise TypeError(
|
||||
f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
|
||||
f"TreeSpec but got item of type {type(treespec)}.",
|
||||
)
|
||||
treespec = _ensure_python_treespec_instance(treespec)
|
||||
|
||||
if protocol is None:
|
||||
protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL
|
||||
@ -1974,16 +2020,22 @@ def treespec_loads(serialized: str) -> TreeSpec:
|
||||
)
|
||||
|
||||
|
||||
class _DummyLeaf:
|
||||
class _Asterisk(str):
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls) -> Self:
|
||||
return super().__new__(cls, "*")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "*"
|
||||
return "*" # no quotes
|
||||
|
||||
|
||||
_asterisk = _Asterisk()
|
||||
del _Asterisk
|
||||
|
||||
|
||||
def treespec_pprint(treespec: TreeSpec) -> str:
|
||||
dummy_tree = tree_unflatten(
|
||||
[_DummyLeaf() for _ in range(treespec.num_leaves)],
|
||||
treespec,
|
||||
)
|
||||
dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec)
|
||||
return repr(dummy_tree)
|
||||
|
||||
|
||||
|
||||
216
torch/utils/pytree/__init__.py
Normal file
216
torch/utils/pytree/__init__.py
Normal file
@ -0,0 +1,216 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
"""
|
||||
Contains utility functions for working with nested python data structures.
|
||||
|
||||
A *pytree* is Python nested data structure. It is a tree in the sense that
|
||||
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
|
||||
Python values. Furthermore, a pytree should not contain reference cycles.
|
||||
|
||||
pytrees are useful for working with nested collections of Tensors. For example,
|
||||
one can use `tree_map` to map a function over all Tensors inside some nested
|
||||
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
|
||||
inside some nested collection. pytrees are helpful for implementing nested
|
||||
collection support for PyTorch APIs.
|
||||
"""
|
||||
|
||||
import os as _os
|
||||
import sys as _sys
|
||||
from typing import Any as _Any, Optional as _Optional
|
||||
|
||||
import torch.utils._pytree as python
|
||||
from torch.utils._exposed_in import exposed_in as _exposed_in
|
||||
from torch.utils._pytree import ( # these type aliases are identical in both implementations
|
||||
FlattenFunc,
|
||||
FlattenWithKeysFunc,
|
||||
FromDumpableContextFunc,
|
||||
PyTree,
|
||||
ToDumpableContextFunc,
|
||||
UnflattenFunc,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PyTreeSpec",
|
||||
"register_pytree_node",
|
||||
"tree_flatten",
|
||||
"tree_unflatten",
|
||||
"tree_iter",
|
||||
"tree_leaves",
|
||||
"tree_structure",
|
||||
"tree_map",
|
||||
"tree_map_",
|
||||
"tree_map_only",
|
||||
"tree_map_only_",
|
||||
"tree_all",
|
||||
"tree_any",
|
||||
"tree_all_only",
|
||||
"tree_any_only",
|
||||
"treespec_pprint",
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
]
|
||||
|
||||
|
||||
# NB: Once this variable is read from the environment, the underlying pytree
|
||||
# implementation is frozen. It cannot be swapped to another at runtime.
|
||||
PYTORCH_USE_CXX_PYTREE: bool = _os.getenv("PYTORCH_USE_CXX_PYTREE", "0") not in {
|
||||
"0",
|
||||
"",
|
||||
}
|
||||
|
||||
|
||||
if PYTORCH_USE_CXX_PYTREE:
|
||||
import torch.utils._cxx_pytree as cxx # noqa: F401
|
||||
|
||||
if not python._cxx_pytree_dynamo_traceable:
|
||||
raise ImportError(
|
||||
"Cannot import package `optree`. "
|
||||
"Please install `optree` via `python -m pip install --upgrade optree`. "
|
||||
"Or set the environment variable `PYTORCH_USE_CXX_PYTREE=0`."
|
||||
)
|
||||
|
||||
|
||||
_sys.modules[f"{__name__}.cxx"] = _sys.modules.get("torch.utils._cxx_pytree") # type: ignore[assignment]
|
||||
|
||||
|
||||
if not PYTORCH_USE_CXX_PYTREE:
    # Default backend: re-export the pure-Python pytree implementation.
    from torch.utils._pytree import (
        is_namedtuple,
        is_namedtuple_class,
        is_namedtuple_instance,
        is_structseq,
        is_structseq_class,
        is_structseq_instance,
        PyTreeSpec,
        register_pytree_node as _register_pytree_node,
        tree_all,
        tree_all_only,
        tree_any,
        tree_any_only,
        tree_flatten,
        tree_iter,
        tree_leaves,
        tree_map,
        tree_map_,
        tree_map_only,
        tree_map_only_,
        tree_structure,
        tree_unflatten,
        treespec_pprint,
    )

    # Re-brand the class so it presents as part of this public module.
    # NOTE(review): the C++ branch below does not apply `_exposed_in` to its
    # PyTreeSpec — presumably optree's class does not permit overriding
    # `__module__`; confirm before unifying the two branches.
    PyTreeSpec = _exposed_in(__name__)(PyTreeSpec)  # type: ignore[misc]
else:
    # Opt-in backend: re-export the C++ pytree implementation backed by
    # `optree` (same public surface as the Python branch above).
    from torch.utils._cxx_pytree import (  # type: ignore[assignment,no-redef]
        is_namedtuple,
        is_namedtuple_class,
        is_namedtuple_instance,
        is_structseq,
        is_structseq_class,
        is_structseq_instance,
        PyTreeSpec,
        register_pytree_node as _register_pytree_node,
        tree_all,
        tree_all_only,
        tree_any,
        tree_any_only,
        tree_flatten,
        tree_iter,
        tree_leaves,
        tree_map,
        tree_map_,
        tree_map_only,
        tree_map_only_,
        tree_structure,
        tree_unflatten,
        treespec_pprint,
    )
|
||||
|
||||
|
||||
# Change `__module__` of reexported public APIs to 'torch.utils.pytree'
# so that they present as members of this package rather than of the
# backend module they were imported from.
for _public_name in (
    "tree_all",
    "tree_all_only",
    "tree_any",
    "tree_any_only",
    "tree_flatten",
    "tree_iter",
    "tree_leaves",
    "tree_map",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    "tree_structure",
    "tree_unflatten",
    "treespec_pprint",
    "is_namedtuple",
    "is_namedtuple_class",
    "is_namedtuple_instance",
    "is_structseq",
    "is_structseq_class",
    "is_structseq_instance",
):
    if _public_name in globals():
        globals()[_public_name] = _exposed_in(__name__)(globals()[_public_name])
del _public_name, _exposed_in
|
||||
|
||||
|
||||
def register_pytree_node(
    cls: type[_Any],
    /,
    # intentionally use `*_func` over `*_fn` to match annotations
    flatten_func: FlattenFunc,
    unflatten_func: UnflattenFunc,
) -> None:
    """Register a container-like type as a pytree node.

    Args:
        cls (type): The Python type to treat as an internal pytree node.
        flatten_func (callable): Called during flattening with an instance of
            ``cls``; must return a pair of (1) an iterable of children to be
            flattened recursively and (2) hashable auxiliary data that is
            stored in the treespec and later handed back to ``unflatten_func``.
        unflatten_func (callable): Called with the unflattened children and the
            auxiliary data produced by ``flatten_func``; must rebuild and
            return an instance of ``cls``.

    Example::

        >>> # xdoctest: +SKIP
        >>> from collections import UserList
        ... class MyList(UserList): pass
        >>> # Register a Python type with lambda functions
        ... register_pytree_node(
        ...     MyList,
        ...     lambda lst: (list(lst), None),
        ...     lambda children, _: MyList(children),
        ... )
    """
    # Delegate to whichever backend implementation was selected at import time.
    _register_pytree_node(cls, flatten_func, unflatten_func)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> _Any:
    """Module-level fallback lookup (PEP 562): lazily import ``cxx`` on first access."""
    if name != "cxx":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Lazy import; cache the module both in sys.modules and in this module's
    # namespace so the import machinery and later attribute lookups skip this hook.
    import torch.utils._cxx_pytree as cxx  # noqa: F811

    _sys.modules[f"{__name__}.cxx"] = globals()["cxx"] = cxx
    return cxx
|
||||
8
torch/utils/pytree/cxx.pyi
Normal file
8
torch/utils/pytree/cxx.pyi
Normal file
@ -0,0 +1,8 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
from .._cxx_pytree import * # noqa: F403
|
||||
from .._cxx_pytree import (
|
||||
__all__ as __all__,
|
||||
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
|
||||
KeyPath as KeyPath,
|
||||
)
|
||||
15
torch/utils/pytree/python.pyi
Normal file
15
torch/utils/pytree/python.pyi
Normal file
@ -0,0 +1,15 @@
|
||||
# Owner(s): ["module: pytree"]
|
||||
|
||||
from .._pytree import * # noqa: F403
|
||||
from .._pytree import (
|
||||
__all__ as __all__,
|
||||
_broadcast_to_and_flatten as _broadcast_to_and_flatten,
|
||||
arg_tree_leaves as arg_tree_leaves,
|
||||
BUILTIN_TYPES as BUILTIN_TYPES,
|
||||
GetAttrKey as GetAttrKey,
|
||||
KeyEntry as KeyEntry,
|
||||
KeyPath as KeyPath,
|
||||
MappingKey as MappingKey,
|
||||
SequenceKey as SequenceKey,
|
||||
SUPPORTED_NODES as SUPPORTED_NODES,
|
||||
)
|
||||
Reference in New Issue
Block a user