Compare commits

..

16 Commits

Author SHA1 Message Date
4f6a767b3c transform fr traces for ft
Summary:
- The ranks in the default PG config are local ranks.
- However, FR trace analysis requires them to be global ranks.
- So, when a CLI flag is set, we transform the local ranks to global ranks before the analysis kicks in (see the sketch below).
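
A minimal sketch of that transform (mirroring the `transform_ft` helper and the `--transform-ft` / `--group-world-size` flags added in the diffs further down; not the exact tool code):

```
def transform_ft(details: dict, group_world_size: int) -> dict:
    # Rewrite each dump's default_pg ranks from replica-local ranks to global ranks.
    for dump_key, dump in details.items():
        rank = dump["rank"]
        for key, pg_config in dump["pg_config"].items():
            if pg_config["desc"] != "default_pg":
                continue
            local_ranks = eval(pg_config["ranks"])  # ranks are stored as a stringified list
            first_rank = (rank // group_world_size) * group_world_size
            details[dump_key]["pg_config"][key]["ranks"] = f"{[r + first_rank for r in local_ranks]}"
    return details
```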
2025-10-29 10:54:55 -07:00
a3fe1825aa Fix incomplete torch.cdist tests (#166507)
The existing tests were incomplete because the parametrized `p` value was never actually passed to `torch.cdist`.
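
A sketch of the corrected pattern (shapes and `p` values here are illustrative, not the exact test):

```
import torch

x = torch.randn(1, 27, 27, requires_grad=True)
y = x.detach().clone()
for p in (1.0, 2.0, 3.0):
    d = torch.cdist(x, y, p=p)  # previously `p` was dropped, so only the default p=2 was exercised
    d.backward(torch.ones_like(d))
    assert torch.isfinite(x.grad).all()  # backward should not produce nan/inf
```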
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166507
Approved by: https://github.com/Skylion007
2025-10-29 17:11:07 +00:00
deb776319b [ROCm] Reduce duplication in bfloat16_support_literal definition (#166147)
This PR refactors the bfloat16_support_literal constant in the PyTorch build logic to eliminate duplicated ROCm-specific code.

Previously, there were two nearly identical branches for ROCM_VERSION < 70000 and ROCM_VERSION >= 70000, differing only by a single typedef. These have been unified into one conditional block with a minimal version guard inside. (https://github.com/ROCm/pytorch/pull/2502)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166147
Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily
2025-10-29 16:59:03 +00:00
d7040e6d75 Revert "[dynamo][guards] 1/N Guard selectively for DTensor (#165824)"
This reverts commit ee7434be822cf6e75b4566d8159f550ee233d8ae.

Reverted https://github.com/pytorch/pytorch/pull/165824 on behalf of https://github.com/anijain2305 due to internal job failed ([comment](https://github.com/pytorch/pytorch/pull/165824#issuecomment-3462667536))
2025-10-29 16:52:31 +00:00
35f3572fa4 Revert "[ROCm] Enable group gemm through CK (#166334)"
This reverts commit 1fa520ea654f5fc0b3c65ce6e056dd73442dd65d.

Reverted https://github.com/pytorch/pytorch/pull/166334 on behalf of https://github.com/atalman due to Internal build failures ([comment](https://github.com/pytorch/pytorch/pull/166334#issuecomment-3462640668))
2025-10-29 16:45:02 +00:00
bc5111cd8d [Inductor] Prevent kernel fusion with too many unique inputs and outputs (#166275)
MTIA Triton currently cannot handle kernels with too many input/output buffers. This PR adds a configurable limit that prevents fusions from producing kernels with too many unique input/output buffers.
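
A minimal usage sketch (the knob name comes from the `torch/_inductor/config.py` hunk further down; the value 8 is an arbitrary example):

```
import torch
import torch._inductor.config as inductor_config

# Cap fused kernels at 8 unique input/output buffers; the default (None) disables the check.
inductor_config.max_fusion_unique_io_buffers = 8

fn = torch.compile(lambda a, b, c, d: a * b + c * d)
out = fn(*[torch.randn(128) for _ in range(4)])
```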

Differential Revision: [D85509351](https://our.internmc.facebook.com/intern/diff/D85509351/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166275
Approved by: https://github.com/eellison
ghstack dependencies: #166274
2025-10-29 16:41:34 +00:00
398fdd32bb [Inductor] Lower fallback nodes annotated with "should_fallback" (#166339)
Summary:
This PR introduces an inductor-level fallback mechanism that gives users control over which operations or subgraphs Inductor should lower and which should fall back to preexisting kernels. It has a similar motivation to #164776: providing the flexibility to selectively disable Inductor lowering for specific nodes.

The implementation simply adds a check for the `"should_fallback"` metadata annotation on FX graph nodes. If it is set to `True`, the lowerer falls back before attempting the normal lowering path. Note that because these are user-directed fallbacks that depend on specific, customized conditions, the fallback handler is invoked with `add_to_fallback_set=False` to avoid permanently overwriting Inductor's lowering/fallback rules.

A simple example marking nodes for fallback based on a custom predicate:

```
def should_fallback_predicate(node: torch.fx.Node, pred: Callable[[torch.fx.Node], bool]) -> None:
    # Apply the predicate and mark the node for fallback if it matches
    if pred(node):
        node.meta["should_fallback"] = True
```
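
A minimal end-to-end usage sketch (assumptions: a hypothetical custom backend named `fallback_adds_backend` that falls back on `aten.add`, mirroring the CI test added in this PR):

```
import torch
from torch._inductor.compile_fx import compile_fx

def fallback_adds_backend(gm: torch.fx.GraphModule, example_inputs):
    # Mark every aten.add.Tensor call so Inductor uses its fallback handler
    # instead of the normal lowering for those nodes.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
            node.meta["should_fallback"] = True
    return compile_fx(gm, example_inputs)

def foo(x, y):
    return (x + y) * 2

compiled = torch.compile(foo, backend=fallback_adds_backend)
out = compiled(torch.randn(10), torch.randn(10))
```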

Test Plan: added a CI test

Differential Revision: D85347587

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166339
Approved by: https://github.com/blaine-rister, https://github.com/eellison
2025-10-29 16:33:55 +00:00
5fd1d41e62 Revert "[user-streams] Make device-agnostic streams weakref compatible (#164304)"
This reverts commit bfc2050db975e589795cd3eceaed2e83bf89ad35.

Reverted https://github.com/pytorch/pytorch/pull/164304 on behalf of https://github.com/atalman due to Breaks periodic: test/dynamo/test_streams.py::TestStreams::test_stream_weakref [GH job link](https://github.com/pytorch/pytorch/actions/runs/18909552619/job/53979171605) [HUD commit link](cde81e92b9) ([comment](https://github.com/pytorch/pytorch/pull/164304#issuecomment-3462489278))
2025-10-29 16:09:54 +00:00
c594950e86 Revert "nn.Linear: nD contiguous input + bias -- dispatch to addmm also when weight is sparse (#166071)"
This reverts commit 467c21ad9ae4133c20a3c098a0355e9ac20d48aa.

Reverted https://github.com/pytorch/pytorch/pull/166071 on behalf of https://github.com/atalman due to Multiple CI breakages: test/profiler/test_profiler_tree.py::TestProfilerTree::test_profiler_experimental_tree_with_stack_and_modules [GH job link](https://github.com/pytorch/pytorch/actions/runs/18909087335/job/53976915830) [HUD commit link](467c21ad9a) ([comment](https://github.com/pytorch/pytorch/pull/166071#issuecomment-3462458968))
2025-10-29 16:05:30 +00:00
14102fb1f3 add new line in log (#164240)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164240
Approved by: https://github.com/ColinPeppler, https://github.com/Skylion007, https://github.com/ezyang
ghstack dependencies: #164075
2025-10-29 16:03:32 +00:00
5cdbcb5233 Revert "[User-streams] Make torch.Event weakref compatible (#164522)"
This reverts commit cde81e92b95eee9af2879c9c75f7b03699ca72ad.

Reverted https://github.com/pytorch/pytorch/pull/164522 on behalf of https://github.com/atalman due to Breaks periodic: test/dynamo/test_streams.py::TestStreams::test_stream_weakref [GH job link](https://github.com/pytorch/pytorch/actions/runs/18909552619/job/53979171605) [HUD commit link](cde81e92b9) ([comment](https://github.com/pytorch/pytorch/pull/164522#issuecomment-3462450571))
2025-10-29 16:03:03 +00:00
eae701cad0 Add scaffolding for StableIValue FC/BC (no PoC) (#164332)
1. Add `extension_build_version` and `is_internal` to `FromImpl`/`ToImpl` (this will be useful in the future if we need to break the BC of any type); #163832 has the PoC of how we would actually use this system
2. Add `aoti_torch_library_impl_v2`, which takes an additional `extension_build_version` argument, and update the callsite in `torch/csrc/stable/library.h` to always pass `TORCH_ABI_VERSION` for this argument
3. Add `extension_build_version` to `from_ivalue` and `to_ivalue` and update all callsites
4. Add a private `_from` and `_to` that pass `is_internal=True` to `FromImpl`/`ToImpl`, making it easier to reason about what is being called from libtorch-land / extension-land

**Note: This PR does not include a linter that tells the user to update from/to if changing the ABI of a type in headeronly, which I intend to do in https://github.com/pytorch/pytorch/pull/163998**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164332
Approved by: https://github.com/janeyx99
ghstack dependencies: #164356, #166373, #163683
2025-10-29 15:41:45 +00:00
8f51556daa Add scaffolding for aoti_torch_call_dispatcher BC with native ops (#163683)
Part 1 of plan in https://docs.google.com/document/d/1MaX51H5aEQE5XnOlnZIpf9oCYwzGrTWkgBACxNzsmWE/edit?usp=sharing

- Upgrade `aoti_torch_call_dispatcher` to v2 with an `extension_build_version`
- Allow registration of StableIValue stack  --> IValue stack adapters for schema changes

#### Note: This PR does not include a linter that tells the user to add the upgrader if the schema changes, which is an important piece that will be added in a separate PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163683
Approved by: https://github.com/janeyx99
ghstack dependencies: #164356, #166373
2025-10-29 15:41:45 +00:00
c0bbda37e8 Move static from_ivalue/to_ivalue to new shim_common.cpp (#166373)
Move `from_ivalue` and `to_ivalue` and their dependents `StableIValueBoxedKernel`, `aoti_torch_library_impl`, and `aoti_torch_call_dispatcher` into a new, non-AOTI `shim_common.cpp`.

This is in prep for the above PRs, where I add version-aware v2s (`torch_call_dispatcher` and `torch_library_impl`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166373
Approved by: https://github.com/janeyx99
ghstack dependencies: #164356
2025-10-29 15:41:36 +00:00
fefb546b91 Add TORCH_TARGET_VERSION for stable ABI (#164356)
And update it so comparisons can be done by the preprocessor

**Note: We also need to gate in shim.h and figure out how to enforce this**

Differential Revision: [D85683549](https://our.internmc.facebook.com/intern/diff/D85683549)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164356
Approved by: https://github.com/janeyx99
2025-10-29 15:41:28 +00:00
d6d6fa26f5 Revert "bwd pass (#164504)"
This reverts commit f36f372acc28062e0988d84699c62689b0d89a6e.

Reverted https://github.com/pytorch/pytorch/pull/164504 on behalf of https://github.com/jeffdaily due to CI had been clean for both cuda and rocm before merge, broke post merge? ([comment](https://github.com/pytorch/pytorch/pull/164504#issuecomment-3462116676))
2025-10-29 15:10:40 +00:00
42 changed files with 1042 additions and 1394 deletions

View File

@ -50,35 +50,18 @@ static inline bool parseLinearFlatten3d() {
// `_flatten_nd_linear` flattens all but the last dimension of the input tensor
// before passing it to linear operation
static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
const auto input_sizes = input.sym_sizes();
const auto result_flattened = [&]() -> Tensor {
const auto input_ncols = input_sizes.back();
const auto input_flattened_nrows = [&]() -> c10::SymInt {
// can't use -1 in reshape because it errors when a dimension is 0
auto flattened_nrows = c10::SymInt{1};
for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) {
flattened_nrows *= size;
}
return flattened_nrows;
}();
const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols});
if (weight.layout() == c10::kStrided) {
return at::addmm(bias, input_flattened, weight.t());
} else {
// weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
// so we transpose the problem.
// NOTE: at::matmul handles (dense @ sparse) similarly.
const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1);
return at::addmm(bias_t, weight, input_flattened.t()).t();
const auto input_sizes = input.sym_sizes();
// can't use -1 in reshape because it errors when a dimension is 0
c10::SymInt flattened_dim = 1;
for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
flattened_dim = flattened_dim * input_sizes[i];
}
}();
// Unflatten flattened row dims
auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()};
result_sizes.back() = result_flattened.sym_size(1);
return result_flattened.view_symint(result_sizes);
auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)});
const auto result = at::addmm(bias, inp_reshape, weight.t());
auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
sizes_vec.push_back(result.sym_size(1));
return result.view_symint(sizes_vec);
}
@ -107,23 +90,15 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
// Fused op is marginally faster.
return at::addmm(*bias, input, weight.t());
}
const auto is_bias_likely_fusable = (
bias->defined() &&
// cuBLASLt: will fuse in the epilogue without copies
// when input/weight/bias are all strided.
// When weight is not strided, bias will not be fused,
// but we can still dispatch here to avoid at::matmul
// path which will probably use a very similar
// flattening optimization.
(bias->dim() == 1 && bias->is_contiguous_or_false())
);
if (is_bias_likely_fusable && !input.is_xla()) {
// Also hit the fused path for contiguous nD input, if not using xla
if (bias->defined() && !input.is_xla()) {
// Also hit the fused path for contiguous 3D input, if not using xla
// backend. Reshaping/flattening has some performance implications on xla.
if (input.is_contiguous_or_false()) {
bool is_contiguous = input.is_contiguous_or_false();
if (is_contiguous && input_dim == 3) {
return _flatten_nd_linear(input, weight, *bias);
} else if (parseLinearFlatten3d()) {
} else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
return _flatten_nd_linear(input, weight, *bias);
} else if (parseLinearFlatten3d() && input_dim == 3) {
// If user forces flattening via env var
const Tensor input_cont = input.contiguous();
return _flatten_nd_linear(input_cont, weight, *bias);

View File

@ -22,9 +22,6 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
@ -639,19 +636,12 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}

View File

@ -1,19 +0,0 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>
namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -1,458 +0,0 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>
#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at {
namespace hip {
namespace detail {
namespace CkTypes {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}
template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
1, 1,
S<1,32,1,8>, 4
>;
template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
at::Tensor& out)
{
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
using PassThrough = CkTypes::PassThrough;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a_ptrs, p_b_ptrs;
std::vector<void*> p_e_ptrs;
// Note: d_ptrs will be resized after we populate the other vectors
const int mat_a_dim = mat_a.dim();
const int mat_b_dim = mat_b.dim();
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
const size_t a_element_size = mat_a.element_size();
const size_t b_element_size = mat_b.element_size();
const size_t out_element_size = out.element_size();
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
if (mat_a_dim == 2 && mat_b_dim == 2) {
// 2D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
const int M = mat_a.size(0); // number of rows in A
const int N = mat_b.size(1); // number of columns in B
const int K = mat_a.size(1); // columns in A == rows in B
// for 2d*2d input, output is 3d.
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
for (int i = 0; i < num_groups; ++i) {
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
int end_k = offs_accessor[i];
int k = end_k - start_k;
//K dimension are sliced, hence select stride(1) always.
//K dimension is always dimension 1, regardless of memory layout (row/column major)
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
const void* group_b_ptr;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
// Leading dimension = distance between rows = stride(0)
ldb = mat_b.stride(0);
} else {
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
// Leading dimension = distance between columns = stride(1)
ldb = mat_b.stride(1);
}
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
lda = mat_a.stride(0);
} else {
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
lda = mat_a.stride(1);
}
// Output is always row-major in 3D tensor [num_groups, M, N]
// Leading dimension for each group's [M,N] slice = stride(1) = N
ldc = out.stride(1);
size_t output_group_bytes = M * N * out_element_size;
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
// 2D*3D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 2d*3d input, output is 2d.
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
const int K = mat_a.size(1); // columns in A
// For 2D-3D case: The output determines N (result width)
const int N = out.size(1); // N is the width of the output tensor
for (int i = 0; i < num_groups; ++i) {
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
int end_m = offs_accessor[i];
int m = end_m - start_m;
// Skip zero-sized groups but continue processing subsequent groups
if (m <= 0) {
continue;
}
// Select A rows for group i: skip start_m rows
const void* group_a_ptr;
int lda;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
lda = mat_a.stride(0); // distance between rows
} else {
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
// Detect stride pattern for A tensor to determine appropriate lda calculation
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
if (a_is_strided_tensor) {
// For strided A tensors: stride(0) gives the actual leading dimension
lda = mat_a.stride(0);
} else {
// For non-strided A tensors: use the M dimension (total rows)
lda = mat_a.size(0); // Total M dimension for column-major layout
}
}
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
int ldb;
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
} else {
// Detect stride pattern to determine appropriate ldb calculation
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
if (is_strided_tensor) {
// For strided tensors: stride(2) gives the actual leading dimension
ldb = mat_b.stride(2);
} else {
// For non-strided tensors: use the N dimension
ldb = mat_b.size(1);
}
}
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
gemm_descs.push_back({
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
// 3d*3d input, output is 3d - batched matrix multiplication
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
// Each batch is processed as a separate GEMM operation
const int batch_size = mat_a.size(0);
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
int N;
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
N = mat_b.size(2);
} else if (mat_b.size(2) == K) {
// B is [batch, n, k] - transposed layout
N = mat_b.size(1);
} else {
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
}
for (int i = 0; i < batch_size; ++i) {
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
// Select output batch for group i: Output[i, :, :]
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldb, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
} else {
// Column-major A: leading dimension = distance between columns = stride(2)
lda = mat_a.stride(2);
}
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B: leading dimension = distance between rows
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(1); // stride between K rows
} else {
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
}
} else {
// Column-major B: leading dimension = distance between columns
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(2); // stride between N columns
} else {
// B is [batch, n, k] - transposed layout
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
}
}
// Output is typically row-major: leading dimension = distance between rows = stride(1)
ldc = out.stride(1);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
// 3D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 3d*2d input, output is 3d.
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
const int batch_size = mat_a.size(0); // n_groups
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A
// For row-major A and B case: B should be [K, total_N]
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
for (int i = 0; i < num_groups; ++i) {
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
int end_n = offs_accessor[i];
int n = end_n - start_n;
// Skip zero-sized groups but continue processing subsequent groups
if (n <= 0) {
continue;
}
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
const void* group_b_ptr;
int ldb;
// Check if B is row-major or column-major
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(0); // distance between rows (should be total_N)
} else {
// Column-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(1); // distance between columns (should be K)
}
// Select output slice for group i: Output[:, start_n:end_n]
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
int lda, ldc;
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
// Output is row-major: leading dimension = distance between rows = stride(0)
ldc = out.stride(0);
gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
}
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
// Initialize d_ptrs with the correct size
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
static DeviceOp gemm_instance;
auto argument = gemm_instance.MakeArgument(
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
);
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
"CK Group GEMM: argument unsupported (shape/strides/type config)");
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
void* gemm_arg_buf = nullptr;
void* ws_buf = nullptr;
hipMalloc(&gemm_arg_buf, arg_buf_size);
hipMalloc(&ws_buf, ws_size);
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
auto invoker = gemm_instance.MakeInvoker();
hipStream_t stream = c10::hip::getCurrentHIPStream();
invoker.Run(argument, {stream});
hipFree(gemm_arg_buf);
hipFree(ws_buf);
}
void group_gemm_ck(
const at::Tensor& input_a,
const at::Tensor& input_b_colmajor,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& /*bias*/,
at::Tensor& out)
{
// Detect if input_a is row-major based on stride pattern
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
// Ensure tensor A is row-major and contiguous if not already
at::Tensor mat_a = input_a;
if (!a_row_major) {
// If A is not row-major, make it contiguous (row-major)
mat_a = input_a.contiguous();
}
// Force tensor B to be column-major using double transpose trick
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
at::Tensor mat_b = input_b_colmajor;
if (!b_col_major) {
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
}
// For 3D tensors, check the last dimension stride for row-major detection
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
if (mat_a.dtype() == at::kBFloat16) {
// bf16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kHalf) {
// fp16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kFloat) {
// fp32 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
}
}
} // namespace detail
} // namespace hip
} // namespace at

View File

@ -482,6 +482,7 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
"torch/csrc/jit/serialization/pickle.cpp",
"torch/csrc/shim_common.cpp",
]
libtorch_core_sources = sorted(

View File

@ -464,25 +464,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
run(g, 64, 8)
self.assertEqual(cnt.frame_count, 2)
def test_dtensor_requires_grad_recompile(self):
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@torch.compile(backend=cnt, fullgraph=True)
def f(x):
y = x * x
return y.to_local()
full_x = torch.randn(8, 8, requires_grad=False)
x = distribute_tensor(full_x, mesh, [Shard(0)])
f(x)
full_x = torch.randn(8, 8, requires_grad=True)
x = distribute_tensor(full_x, mesh, [Shard(0)])
f(x)
self.assertEqual(cnt.frame_count, 2)
def test_dtensor_attribute_access_on_intermediate(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

View File

@ -234,6 +234,27 @@ class InPlaceCompilationTests(TestCase):
with self.assertRaises(IndexError):
fn(torch.randn(10), 99)
def test_list_bad_weakref(self):
import weakref
a = torch.Event()
with self.assertRaises(TypeError):
weakref.ref(a)
@torch.compile(backend="eager")
class Mod(torch.nn.Module):
def __init__(self, event):
super().__init__()
self.event = event
def forward(self, x):
return x * int(self.event.query())
e = torch.Event()
m = Mod(e)
a = torch.randn(10)
self.assertEqual(m(a), a)
# The private variants of the below functions are extensively tested
# So as long as the signatures match we're good

View File

@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"]
import weakref
import torch
import torch._dynamo.test_case
@ -16,14 +15,6 @@ class TestStreams(torch._dynamo.test_case.TestCase):
def tearDownClass(cls):
super().tearDownClass()
def test_stream_weakref(self):
s = torch.Stream()
weakref.ref(s)
def test_event_weakref(self):
e = torch.Event()
weakref.ref(e)
@requires_cuda
def test_run_opcheck(self):
from torch._dynamo.variables.streams import fork_stream, join_stream

View File

@ -1,11 +1,14 @@
# Owner(s): ["module: inductor"]
from unittest import skipIf
from unittest.mock import Mock
import torch
import torch._inductor.metrics as metrics
import torch.utils.flop_counter
from torch._dynamo.utils import counters
from torch._inductor.dependencies import Dep, ReadWrites
from torch._inductor.scheduler import BaseSchedulerNode, Scheduler
from torch._inductor.utils import fresh_inductor_cache
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_device_type import (
@ -15,6 +18,7 @@ from torch.testing._internal.common_device_type import (
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
from torch.testing._internal.inductor_utils import IS_BIG_GPU
from torch.utils._ordered_set import OrderedSet
def FlopCounterMode(*args, **kwargs):
@ -132,6 +136,79 @@ class TestScheduler(TestCase):
counters["inductor"]["flop_count"] = 0
torch._logging.set_logs()
def test_fusion_prevent_too_many_reads_and_writes_prevents_fusion(self):
"""Test that fusion is prevented when unique I/O buffers exceed threshold"""
# Setup: Create nodes with many unique I/O buffers
# node1: reads [A, B, C], writes [D]
# node2: reads [D, E, F], writes [G]
# D becomes internal (node2 reads node1's write)
# After fusion: unique I/O = {A, B, C, E, F, G} = 6 buffers
scheduler = Mock(spec=Scheduler)
scheduler.can_buffer_be_removed_through_fusion = Mock(return_value=False)
node1 = self._create_mock_node(
name="node1", reads=["A", "B", "C"], writes=["D"]
)
node2 = self._create_mock_node(
name="node2", reads=["D", "E", "F"], writes=["G"]
)
# Execute: Check with threshold of 5 (should prevent fusion since 6 > 5)
result = Scheduler.fusion_prevent_too_many_reads_and_writes(
scheduler, node1, node2, threshold=5
)
# Assert: Fusion should be prevented (6 unique buffers > 5 threshold)
self.assertTrue(result)
def test_fusion_prevent_too_many_reads_and_writes_allows_fusion(self):
"""Test that fusion is allowed when intermediate buffers are removed"""
# Setup: Create nodes where node2 reads node1's output
# node1: reads [A, B], writes [C]
# node2: reads [C, D], writes [E]
# C becomes internal (node2 reads node1's write)
# After fusion: unique I/O = {A, B, D, E} = 4 buffers
scheduler = Mock(spec=Scheduler)
scheduler.can_buffer_be_removed_through_fusion = Mock(return_value=False)
node1 = self._create_mock_node(name="node1", reads=["A", "B"], writes=["C"])
node2 = self._create_mock_node(name="node2", reads=["C", "D"], writes=["E"])
# Execute: Check with threshold of 5 (should allow fusion since 4 <= 5)
result = Scheduler.fusion_prevent_too_many_reads_and_writes(
scheduler, node1, node2, threshold=5
)
# Assert: Fusion should be allowed (4 unique buffers <= 5 threshold)
self.assertFalse(result)
def _create_mock_node(self, name: str, reads: list[str], writes: list[str]) -> Mock:
"""Helper method to create a mock scheduler node with specified reads/writes"""
node = Mock(spec=BaseSchedulerNode)
node.get_name = Mock(return_value=name)
node.get_nodes = Mock(return_value=[node])
# Create mock Dep objects for reads and writes
read_deps = OrderedSet()
for read_name in reads:
dep = Mock(spec=Dep)
dep.name = read_name
read_deps.add(dep)
write_deps = OrderedSet()
for write_name in writes:
dep = Mock(spec=Dep)
dep.name = write_name
write_deps.add(dep)
# Create mock ReadWrites object
read_writes = Mock(spec=ReadWrites)
read_writes.reads = read_deps
read_writes.writes = write_deps
node.read_writes = read_writes
return node
instantiate_device_type_tests(TestScheduler, globals())

View File

@ -0,0 +1,91 @@
# Owner(s): ["module: inductor"]
"""
Test selective lowering control via node metadata annotations.
"""
from collections.abc import Callable
import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
@instantiate_parametrized_tests
class SelectiveLoweringTest(InductorTestCase):
"""
Tests for user-controllable selective lowering using node.meta annotations.
"""
device = GPU_TYPE
def _mark_nodes_for_fallback(
self, gm: torch.fx.GraphModule, predicate: Callable[[torch.fx.Node], bool]
) -> torch.fx.GraphModule:
"""
Helper method to mark nodes with should_fallback metadata based on a predicate.
"""
for node in gm.graph.nodes:
if node.op == "call_function" and predicate(node):
node.meta["should_fallback"] = True
return gm
def test_basic_selective_lowering(self):
"""
Test that nodes marked for fallback use fallback handlers instead of lowerings.
"""
def foo(x, y):
a = x + y # This will be marked for fallback
b = a * 2 # This will use normal lowering
return b
x = torch.randn(10, device=self.device)
y = torch.randn(10, device=self.device)
def custom_backend(gm: torch.fx.GraphModule, example_inputs):
# Mark all add operations for fallback
def should_fallback_add(node: torch.fx.Node) -> bool:
return node.target == torch.ops.aten.add.Tensor
self._mark_nodes_for_fallback(gm, should_fallback_add)
from torch._inductor.compile_fx import compile_fx
return compile_fx(gm, example_inputs)
compiled_fn = torch.compile(foo, backend=custom_backend)
result = compiled_fn(x, y)
expected = foo(x, y)
self.assertTrue(torch.allclose(result, expected))
def test_no_fallback_when_unmarked(self):
"""
Test that operations without fallback annotation use normal lowering.
"""
def foo(x, y):
return x + y
x = torch.randn(10, device=self.device)
y = torch.randn(10, device=self.device)
def custom_backend(gm: torch.fx.GraphModule, example_inputs):
# Don't mark anything - all operations should use normal lowering
from torch._inductor.compile_fx import compile_fx
return compile_fx(gm, example_inputs)
compiled_fn = torch.compile(foo, backend=custom_backend)
result = compiled_fn(x, y)
expected = foo(x, y)
self.assertTrue(torch.allclose(result, expected))
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_GPU:
run_tests(needs="filelock")

View File

@ -459,6 +459,8 @@ class TestMatmulCuda(InductorTestCase):
@parametrize("b_row_major", [False, True])
@dtypes(torch.bfloat16, torch.float32, torch.float16)
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
self.skipTest("failed using hipblaslt on rocm 6.4.2")
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4

View File

@ -2598,7 +2598,7 @@ class TestTorchDeviceType(TestCase):
dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
y = x.clone()
x.requires_grad = True
d = torch.cdist(x, y)
d = torch.cdist(x, y, p=p)
d.backward(dist_grad)
# Check that the backward pass does not contain invalid
# values such as nan or inf

View File

@ -5,12 +5,11 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.varlen import varlen_attn
from torch.nn.attention import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.utils._python_dispatch import TorchDispatchMode
VarlenShape = namedtuple(
@ -24,18 +23,6 @@ default_tolerances = {
}
class OpLoggingMode(TorchDispatchMode):
"""Logging mode that captures all dispatched operations"""
def __init__(self):
self.called_ops = []
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
op_name = str(func)
self.called_ops.append(op_name)
return func(*args, **(kwargs or {}))
class AttentionBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
@ -52,9 +39,12 @@ class AttentionBlock(nn.Module):
embed_dim, embed_dim, bias=False, device=device, dtype=dtype
)
def get_varlen_qkv(
def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
qkv = self.qkv_proj(x_packed)
q, k, v = qkv.chunk(3, dim=-1)
@ -63,51 +53,24 @@ class AttentionBlock(nn.Module):
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)
return q, k, v
def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
q, k, v = self.get_varlen_qkv(x_packed)
attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
attn_out = varlen_attn(
q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
)
attn_out = attn_out.view(-1, self.embed_dim)
return self.out_proj(attn_out)
def forward_sdpa(
self,
x_padded: torch.Tensor,
seq_lengths: torch.Tensor,
dtype: torch.dtype,
is_causal: bool = False,
):
def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
batch_size, seq_len, _ = x_padded.shape
qkv = self.qkv_proj(x_padded)
q, k, v = qkv.chunk(3, dim=-1)
mask = (
torch.arange(seq_len, device=x_padded.device)[None, :]
< seq_lengths[:, None]
)
attn_mask = mask[:, None, None, :].expand(
batch_size, self.num_heads, seq_len, seq_len
)
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=is_causal
)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
attn_out = (
attn_out.transpose(1, 2)
.contiguous()
@ -128,9 +91,7 @@ def create_variable_length_batch(
seq_lengths = torch.tensor(seq_lengths, device=device)
total_tokens = seq_lengths.sum().item()
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
)
x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)
cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
@ -145,7 +106,6 @@ def create_variable_length_batch(
end_idx = start_idx + seq_len
x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
start_idx = end_idx
x_padded = x_padded.clone().detach().requires_grad_()
return {
"seq_lengths": seq_lengths,
@ -173,11 +133,7 @@ class TestVarlenAttention(NNTestCase):
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
total_tokens, shape.embed_dim, device=device, dtype=dtype
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
@ -191,128 +147,6 @@ class TestVarlenAttention(NNTestCase):
self.assertEqual(output.device, torch.device(device))
self.assertEqual(output.dtype, dtype)
varlen_grad_out = torch.ones_like(output)
varlen_grad = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
self.assertIsNotNone(varlen_grad)
self.assertEqual(varlen_grad.shape, x_packed.shape)
self.assertEqual(varlen_grad.dtype, x_packed.dtype)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_compliance(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
q, k, v = attention_block.get_varlen_qkv(x_packed)
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn,
(q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
)
out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
)
grad_out = torch.randn_like(out)
# we don't support double backward
# skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn_backward,
(
grad_out,
q,
k,
v,
out,
lse,
cu_seq,
cu_seq,
shape.max_seq_len,
shape.max_seq_len,
False,
rng_state,
),
test_utils=["test_schema", "test_faketensor"],
)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_registration(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
compiled_forward = torch.compile(
attention_block.forward_varlen, backend="eager", fullgraph=True
)
with OpLoggingMode() as mode:
output = compiled_forward(
x_packed, cu_seq, shape.max_seq_len, is_causal=False
)
varlen_grad_out = torch.ones_like(output)
_ = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
called_ops = mode.called_ops
custom_ops_called = any(
"torch_attn._varlen_attn" in op for op in called_ops
) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
assert custom_ops_called
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@ -338,10 +172,7 @@ class TestVarlenAttention(NNTestCase):
is_causal=is_causal,
)
sdpa_output = attention_block.forward_sdpa(
variable_length_batch_data["x_padded"],
variable_length_batch_data["seq_lengths"],
dtype=dtype,
is_causal=is_causal,
variable_length_batch_data["x_padded"], is_causal=is_causal
)
tolerances = default_tolerances[dtype]
@ -355,44 +186,6 @@ class TestVarlenAttention(NNTestCase):
torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
start_idx = end_idx
varlen_grad_out = torch.ones_like(varlen_output)
sdpa_grad_out = torch.zeros_like(sdpa_output)
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
start_idx = end_idx
varlen_grad = torch.autograd.grad(
outputs=varlen_output,
inputs=variable_length_batch_data["x_packed"],
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
sdpa_grad = torch.autograd.grad(
outputs=sdpa_output,
inputs=variable_length_batch_data["x_padded"],
grad_outputs=sdpa_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
varlen_grad_seq = varlen_grad[start_idx:end_idx]
sdpa_grad_seq = sdpa_grad[i, :seq_len]
torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
start_idx = end_idx
device_types = ("cuda",)

View File

@ -374,6 +374,22 @@ def build_collectives(
return tracebacks, collectives, nccl_calls
def transform_ft(
details: dict[str, dict[str, Any]], group_world_size: int
) -> dict[str, dict[str, Any]]:
for dump_key, dump in details.items():
rank = dump["rank"]
for key, pg_config in dump["pg_config"].items():
if pg_config["desc"] == "default_pg":
ranks = eval(pg_config["ranks"])
replica_id = rank // group_world_size
first_rank = replica_id * group_world_size
new_ranks = [r + first_rank for r in ranks]
details[dump_key]["pg_config"][key]["ranks"] = f"{new_ranks}"
return details
def build_db(
details: dict[str, dict[str, Any]], args: argparse.Namespace, version: str
) -> Database:

View File

@ -74,6 +74,17 @@ class JobConfig:
default=10,
help="Maximum number of mismatches we print (from earliest).",
)
self.parser.add_argument(
"--transform-ft",
action="store_true",
help="Transform PG config to use global ranks to analyze traces produced by torchft",
)
self.parser.add_argument(
"--group-world-size",
type=int,
default=None,
help="The number of ranks in 1 torchft replica group. Must be specified if --transform-ft is True",
)
def parse_args(
self: "JobConfig", args: Optional[Sequence[str]]

View File

@ -32,7 +32,7 @@ import pickle
from collections.abc import Sequence
from typing import Optional
from tools.flight_recorder.components.builder import build_db
from tools.flight_recorder.components.builder import build_db, transform_ft
from tools.flight_recorder.components.config_manager import JobConfig
from tools.flight_recorder.components.loader import read_dir
from tools.flight_recorder.components.types import types
@ -46,6 +46,9 @@ def main(args: Optional[Sequence[str]] = None) -> None:
assert args.trace_dir, "Trace directory trace_dir is required"
# pyrefly: ignore [bad-argument-type]
details, version = read_dir(args)
if args.transform_ft:
assert args.group_world_size, "World size is required for transform_ft"
details = transform_ft(details, args.group_world_size)
# pyrefly: ignore [bad-argument-type]
db = build_db(details, args, version)
# pyrefly: ignore [missing-attribute]

View File

@ -2150,19 +2150,6 @@ class GuardBuilder(GuardBuilderBase):
metadata_checker, get_verbose_code_parts(global_name, guard)
)
def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None:
# Copied from DTensor __metadata_guard__
# TODO - Consider moving this to C++ if stable
value = deepcopy(self.get(guard.name))
def guard_fn(x: Any) -> bool:
return x._check_equals(value, skip_shapes=True)
code = f"__dtensor_spec_{id(guard_fn)}"
self.get_guard_manager(guard).add_lambda_guard(
guard_fn, get_verbose_code_parts(code, guard)
)
def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
ref = self.arg_ref(guard)
val = self.get(guard.name)

View File

@ -346,10 +346,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)
def _is_pytreespec_instance(
obj: Any, /
) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]:
return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec))
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_leaf,
@ -552,7 +550,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return treespec.unflatten(leaves)

View File

@ -2229,70 +2229,25 @@ class VariableBuilder:
if isinstance(source, GradSource) and is_from_optimizer_source(source):
guard_type = GuardBuilder.NOT_NONE_MATCH
is_dtensor = torch.distributed.is_available() and isinstance(
value, torch.distributed.tensor.DTensor
)
if not is_dtensor:
# We guard on the _local_tensor and the _spec, and therefore we dont
# have to guard on the outer DTensor.
self.install_guards(
functools.partial(
guard_type,
value=(
value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value)
),
)
self.install_guards(
functools.partial(
guard_type,
value=(
value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value)
),
)
)
# We install TYPE_MATCH guards for traceable wrapper subclass object,
# and recursively install corresponding guard for each inner attribute.
if is_traceable_wrapper_subclass(value):
# Tensor subclass guards are very expensive because they are
# implemented in Python. Since DTensor is PyTorch-maintained class,
# we can skip a lot of these guards.
if is_dtensor:
self.install_guards(GuardBuilder.TYPE_MATCH)
# The inner tensor name is always _local_tensor. If its not, we
# raise assertion to update the check accordingly.
inner_tensor_name = value.__tensor_flatten__()[0][0]
if inner_tensor_name != "_local_tensor":
raise RuntimeError(
"Expecting Dtensor inner tensor name to be _local_tensor"
)
# Now selectively guard on the flattening context
flattening_ctx = value.__tensor_flatten__()[1]
# This is supposed to be (self._spec, self.requires_grad)
if not (
len(flattening_ctx) == 2
and flattening_ctx[0] == value._spec
and flattening_ctx[1] == value.requires_grad
):
# If not, raise an assertion to update to the new guards
raise RuntimeError(
"Expecting Dtensor flattening ctx to be _spec, requires_grad"
)
# Guard on the dtensor spec
install_guard(
AttrSource(self.source, "_spec").make_guard(
GuardBuilder.DTENSOR_SPEC_MATCH
)
)
# Move this to C++
install_guard(
AttrSource(self.source, "requires_grad").make_guard(
GuardBuilder.EQUALS_MATCH
)
)
else:
self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
self.install_guards(GuardBuilder.TYPE_MATCH)
install_guard(
SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
)
self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
self.install_guards(GuardBuilder.TYPE_MATCH)
install_guard(
SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
)
attrs, _ = value.__tensor_flatten__()
for attr in attrs:

View File

@ -530,6 +530,17 @@ class InductorChoices:
WhyNoFuse(node1, node2)("Fusion will increase peak memory")
return False
if (
config.max_fusion_unique_io_buffers is not None
and scheduler.fusion_prevent_too_many_reads_and_writes(
node1,
node2,
config.max_fusion_unique_io_buffers,
)
):
WhyNoFuse(node1, node2)("fusion_prevent_too_many_reads_and_writes")
return False
return True
@staticmethod

View File

@ -688,6 +688,10 @@ max_fusion_size = 64
# how many nodes to attempt pairwise fusion with in a buffer group
max_fusion_buffer_group_pairwise_attempts = 64
# maximum number of unique input/output buffers allowed in fused kernels.
# The check is disabled if set to None.
max_fusion_unique_io_buffers: Optional[int] = None
# max number of inputs to generate cat as a pointwise op with masked loads
max_pointwise_cat_inputs = 8

View File

@ -1322,7 +1322,12 @@ class GraphLowering(torch.fx.Interpreter):
else:
args, kwargs = layout_constraints(n, *args, **kwargs)
out = lowerings[target](*args, **kwargs) # type: ignore[index]
if "should_fallback" in n.meta:
out = fallback_handler(target, add_to_fallback_set=False)(
*args, **kwargs
)
else:
out = lowerings[target](*args, **kwargs) # type: ignore[index]
if layout_constraints:
# layout_constraints are allowed to make new copies of the inputs.
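
A hedged sketch of the annotation side of this check (the op filter below is an arbitrary illustration, not part of the PR): a custom FX pass sets the metadata that the lowering branch above reads.

```python
# Nodes marked this way are lowered via fallback_handler(..., add_to_fallback_set=False).
import torch

def mark_fallbacks(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        # Example condition only: fall back for every aten.mm call.
        if node.op == "call_function" and node.target is torch.ops.aten.mm.default:
            node.meta["should_fallback"] = True
    return gm
```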

View File

@ -4113,6 +4113,54 @@ class Scheduler:
return True
return False
def fusion_prevent_too_many_reads_and_writes(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int
) -> bool:
# After fusion, we need to calculate the unique I/O buffers
# accounting for buffers that become internal (removed through fusion)
# Get all nodes that will be in the fused node
fused_node_names = OrderedSet(
[node.get_name() for node in node1.get_nodes()]
+ [node.get_name() for node in node2.get_nodes()]
)
# Calculate node2 reads that can be removed through fusion,
# i.e. node2 reads that are outputs of node1
node1_write_names = OrderedSet(dep.name for dep in node1.read_writes.writes)
node2_read_names = OrderedSet(dep.name for dep in node2.read_writes.reads)
reads_removed_through_fusion = node2_read_names & node1_write_names
# Calculate node1 writes that can be removed through fusion,
# i.e. node1 writes that are only read by node2
writes_removed_through_fusion: OrderedSet[str] = OrderedSet()
for write_dep in node1.read_writes.writes:
if self.can_buffer_be_removed_through_fusion(
write_dep.name, fused_node_names
):
writes_removed_through_fusion.add(write_dep.name)
# Get all unique reads (union of both nodes' reads)
all_read_names = OrderedSet(
dep.name for dep in node1.read_writes.reads
) | OrderedSet(dep.name for dep in node2.read_writes.reads)
# Get all unique writes (union of both nodes' writes)
all_write_names = OrderedSet(
dep.name for dep in node1.read_writes.writes
) | OrderedSet(dep.name for dep in node2.read_writes.writes)
# Remove reads that become internal
unique_reads = all_read_names - reads_removed_through_fusion
# Remove writes that become internal
unique_writes = all_write_names - writes_removed_through_fusion
# Get all unique buffer names (reads and writes combined, but no double counting)
unique_io_buffers = unique_reads | unique_writes
return len(unique_io_buffers) > threshold
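
To make the counting concrete, here is a tiny self-contained illustration with made-up buffer names (plain Python sets, not scheduler code):

```python
# node1 reads {in0, in1} and writes {buf0}; node2 reads {buf0, in2} and writes {out}.
# If buf0 is only consumed by node2, it becomes internal after fusion.
node1_reads, node1_writes = {"in0", "in1"}, {"buf0"}
node2_reads, node2_writes = {"buf0", "in2"}, {"out"}

reads_removed = node2_reads & node1_writes                        # {"buf0"}
writes_removed = {"buf0"}                                         # assuming no other consumer
unique_reads = (node1_reads | node2_reads) - reads_removed        # {"in0", "in1", "in2"}
unique_writes = (node1_writes | node2_writes) - writes_removed    # {"out"}
print(len(unique_reads | unique_writes))                          # 4 unique I/O buffers
```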
def are_long_distant_nodes(
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:

View File

@ -49,7 +49,6 @@ static PyObject* THPEvent_pynew(
}
THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
self->weakreflist = nullptr;
// TODO: blocking and interprocess are not supported yet. To support them, the
// flag system of c10::Event needs to be refactored. C10::Event should also
@ -74,7 +73,6 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
TORCH_CHECK(self, "Failed to allocate memory for Event");
auto self_ = reinterpret_cast<THPEvent*>(self.get());
self_->weakreflist = nullptr;
new (&self_->event) c10::Event(device_type, flag);
return self.release();
}
@ -84,7 +82,6 @@ static void THPEvent_dealloc(THPEvent* self) {
pybind11::gil_scoped_release no_gil{};
self->event.~Event();
}
PyObject_ClearWeakRefs((PyObject*)self);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@ -285,8 +282,7 @@ static PyMethodDef THPEvent_methods[] = {
{"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
{"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
{nullptr}};
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
PyTypeObject THPEventType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch.Event", /* tp_name */
@ -312,7 +308,7 @@ PyTypeObject THPEventType = {
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
THPEvent_methods, /* tp_methods */
@ -327,7 +323,6 @@ PyTypeObject THPEventType = {
nullptr, /* tp_alloc */
THPEvent_pynew, /* tp_new */
};
#pragma GCC diagnostic pop
void THPEvent_init(PyObject* module) {
THPEventClass = &THPEventType;

View File

@ -7,7 +7,6 @@
struct TORCH_API THPEvent {
PyObject_HEAD
c10::Event event;
PyObject* weakreflist;
};
TORCH_API extern PyTypeObject* THPEventClass;
TORCH_API extern PyTypeObject THPEventType;

View File

@ -95,7 +95,6 @@ static PyObject* THPStream_pynew(
self->device_index = static_cast<int64_t>(stream_opt->device_index());
self->device_type = static_cast<int64_t>(stream_opt->device_type());
self->context = nullptr;
self->weakreflist = nullptr;
return static_cast<PyObject*>(ptr.release());
END_HANDLE_TH_ERRORS
@ -115,13 +114,11 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
self->device_index = static_cast<int64_t>(stream.device_index());
self->device_type = static_cast<int64_t>(stream.device_type());
self->context = nullptr;
self->weakreflist = nullptr;
return ptr.release();
END_HANDLE_TH_ERRORS
}
static void THPStream_dealloc(THPStream* self) {
PyObject_ClearWeakRefs((PyObject*)self);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@ -447,7 +444,7 @@ static PyTypeObject THPStreamType = {
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
THPStream_richcompare, /* tp_richcompare */
offsetof(THPStream, weakreflist), /* tp_weaklistoffset */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
// NOLINTNEXTLINE(*const-cast)

View File

@ -13,7 +13,6 @@ struct THPStream {
int64_t device_index;
// Used to switch stream context management, initialized lazily.
PyObject* context;
PyObject* weakreflist;
};
extern TORCH_API PyTypeObject* THPStreamClass;

View File

@ -1406,169 +1406,6 @@ AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
});
}
static StableIValue from_ivalue(
const c10::TypePtr& type,
const c10::IValue& ivalue) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
return torch::stable::detail::from(ath);
}
case c10::TypeKind::IntType: {
return torch::stable::detail::from(ivalue.toInt());
}
case c10::TypeKind::FloatType: {
return torch::stable::detail::from(ivalue.toDouble());
}
case c10::TypeKind::BoolType: {
return torch::stable::detail::from(ivalue.toBool());
}
case c10::TypeKind::ScalarTypeType: {
return torch::stable::detail::from(ivalue.toScalarType());
}
case c10::TypeKind::DeviceObjType: {
return torch::stable::detail::from(ivalue.toDevice());
}
case c10::TypeKind::LayoutType: {
return torch::stable::detail::from(ivalue.toLayout());
}
case c10::TypeKind::MemoryFormatType: {
return torch::stable::detail::from(ivalue.toMemoryFormat());
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return
// torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with torch::stable::detail::from<std::optional<T>>
// function in torch/csrc/stable/stableivalue_conversions.h
if (ivalue.isNone()) {
return torch::stable::detail::from(std::nullopt);
}
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
return torch::stable::detail::from(sivp);
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from IValue to StableIValue for schema type: ",
type->str());
}
}
}
static c10::IValue to_ivalue(
const c10::TypePtr& type,
const StableIValue stable_ivalue) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
ret_raiiath.get())));
}
case c10::TypeKind::IntType: {
return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
}
case c10::TypeKind::FloatType: {
return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
}
case c10::TypeKind::BoolType: {
return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
}
case c10::TypeKind::ScalarTypeType: {
return c10::IValue(
torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
}
case c10::TypeKind::DeviceObjType: {
return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
}
case c10::TypeKind::LayoutType: {
return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
}
case c10::TypeKind::MemoryFormatType: {
return c10::IValue(
torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return
// c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with the torch::stable::detail::to<T> function in
// torch/csrc/stable/stableivalue_conversions.h
if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
return c10::IValue();
}
auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
auto ival = to_ivalue(inner_type, *sivp);
delete sivp;
return ival;
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from StableIValue to IValue for schema type: ",
type->str());
}
}
}
class StableIValueBoxedKernel : public c10::OperatorKernel {
public:
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
: fn_(fn) {}
void operator()(
const c10::OperatorHandle& op,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
auto ministack =
std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
for (const auto idx : c10::irange(num_arguments)) {
const auto ministack_idx = num_arguments - idx - 1;
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
}
// boxed function is going to take a stack of StableIValues, cast them to
// our schema values, and run the function and modify the StableIValue stack
fn_(ministack.get(), num_arguments, num_returns);
// read the output from the end of the stack and wrap that back into
// IValue from StableIValue
for (size_t idx = 0; idx < num_returns; idx++) {
const c10::TypePtr& ret_type = schema.returns()[idx].type();
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
}
}
private:
void (*fn_)(StableIValue*, uint64_t, uint64_t);
};
AOTITorchError aoti_torch_library_init_impl(
const char* ns,
const char* k,
@ -1618,18 +1455,6 @@ AOTITorchError aoti_torch_library_init_fragment(
});
}
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
TorchLibraryHandle self,
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
reinterpret_cast<torch::Library*>(self)->impl(
name,
torch::CppFunction::makeFromBoxedFunctor(
std::make_unique<StableIValueBoxedKernel>(fn)));
});
}
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
@ -1642,40 +1467,6 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
{ delete reinterpret_cast<torch::Library*>(tlh); });
}
AOTITorchError aoti_torch_call_dispatcher(
const char* opName,
const char* overloadName,
StableIValue* stack) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
const auto op =
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
torch::jit::Stack ivalue_stack;
// we will only need max(num_args, num_returns)
ivalue_stack.reserve(std::max(num_arguments, num_returns));
// convert StableIValue stack to c10::IValue stack
for (const auto idx : c10::irange(num_arguments)) {
auto stable_ivalue = stack[idx];
auto arg_type = schema.arguments()[idx].type();
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
}
op.callBoxed(ivalue_stack);
// there should then be num_returns IValues on the stack, which
// we will convert to StableIValue and repopulate user input stack
for (const auto idx : c10::irange(num_returns)) {
const auto stack_idx = num_returns - idx - 1;
const c10::TypePtr& ret_type = schema.returns()[idx].type();
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
}
});
}
AOTITorchError aoti_torch_create_device_guard(
int32_t device_index,
DeviceGuardHandle* ret_guard // returns new reference

View File

@ -260,82 +260,20 @@ typedef __half half;
)";
#endif
#if defined(USE_ROCM) && ROCM_VERSION < 70000
#if defined(USE_ROCM)
#if ROCM_VERSION >= 70000
#define BF16_UINT32_DEF "typedef unsigned int uint32_t;\n"
#else
#define BF16_UINT32_DEF ""
#endif
constexpr auto bfloat16_support_literal =
R"(
#ifndef __align__
#define __align__(x) __attribute__((aligned(x)))
#endif
typedef struct __align__(2) {
unsigned short x;
}
__nv_bfloat16_raw;
#if defined(__cplusplus)
struct __align__(2) __nv_bfloat16 {
__host__ __device__ __nv_bfloat16() {}
__host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
__x = hr.x;
return *this;
}
unsigned short __x;
};
__device__ unsigned short __internal_float2bfloat16(
const float f,
unsigned int& sign,
unsigned int& remainder) {
unsigned int x;
x = __float_as_uint(f);
if ((x & 0x7fffffffU) > 0x7f800000U) {
sign = 0U;
remainder = 0U;
return static_cast<unsigned short>(0x7fffU);
}
sign = x >> 31;
remainder = x << 16;
return static_cast<unsigned short>(x >> 16);
}
/* Definitions of intrinsics */
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
__nv_bfloat16 val;
__nv_bfloat16_raw r;
unsigned int sign;
unsigned int remainder;
r.x = __internal_float2bfloat16(a, sign, remainder);
if ((remainder > 0x80000000U) ||
((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
r.x++;
}
val = r;
return val;
}
__device__ float __bfloat162float(const __nv_bfloat16 a) {
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(a.__x) << 16};
return u.fp32;
}
#endif /* defined(__cplusplus) */
)";
#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
constexpr auto bfloat16_support_literal =
R"(
#ifndef __align__
#define __align__(x) __attribute__((aligned(x)))
#endif
typedef unsigned int uint32_t;
)" BF16_UINT32_DEF R"(
typedef struct __align__(2) {
unsigned short x;
}

torch/csrc/shim_common.cpp (new file, 417 lines)
View File

@ -0,0 +1,417 @@
#include <c10/core/DispatchKey.h>
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/stable/library.h>
#include <torch/library.h>
#include <torch/csrc/stable/c/shim.h>
static StableIValue from_ivalue(
const c10::TypePtr& type,
const c10::IValue& ivalue,
uint64_t extension_build_version) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
return torch::stable::detail::_from(ath, extension_build_version);
}
case c10::TypeKind::IntType: {
return torch::stable::detail::_from(
ivalue.toInt(), extension_build_version);
}
case c10::TypeKind::FloatType: {
return torch::stable::detail::_from(
ivalue.toDouble(), extension_build_version);
}
case c10::TypeKind::BoolType: {
return torch::stable::detail::_from(
ivalue.toBool(), extension_build_version);
}
case c10::TypeKind::ScalarTypeType: {
return torch::stable::detail::_from(
ivalue.toScalarType(), extension_build_version);
}
case c10::TypeKind::DeviceObjType: {
return torch::stable::detail::_from(
ivalue.toDevice(), extension_build_version);
}
case c10::TypeKind::LayoutType: {
return torch::stable::detail::_from(
ivalue.toLayout(), extension_build_version);
}
case c10::TypeKind::MemoryFormatType: {
return torch::stable::detail::_from(
ivalue.toMemoryFormat(), extension_build_version);
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return
// torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with torch::stable::detail::from<std::optional<T>>
// function in torch/csrc/stable/stableivalue_conversions.h
if (ivalue.isNone()) {
return torch::stable::detail::_from(
std::nullopt, extension_build_version);
}
StableIValue* sivp = new StableIValue(
from_ivalue(inner_type, ivalue, extension_build_version));
return torch::stable::detail::_from(sivp, extension_build_version);
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from IValue to StableIValue for schema type: ",
type->str());
}
}
}
static c10::IValue to_ivalue(
const c10::TypePtr& type,
const StableIValue stable_ivalue,
uint64_t extension_build_version) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
torch::stable::detail::_to<AtenTensorHandle>(
stable_ivalue, extension_build_version));
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
ret_raiiath.get())));
}
case c10::TypeKind::IntType: {
return c10::IValue(torch::stable::detail::_to<int64_t>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::FloatType: {
return c10::IValue(torch::stable::detail::_to<double>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::BoolType: {
return c10::IValue(torch::stable::detail::_to<bool>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::ScalarTypeType: {
return c10::IValue(torch::stable::detail::_to<c10::ScalarType>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::DeviceObjType: {
return c10::IValue(torch::stable::detail::_to<c10::Device>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::LayoutType: {
return c10::IValue(torch::stable::detail::_to<c10::Layout>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::MemoryFormatType: {
return c10::IValue(torch::stable::detail::_to<c10::MemoryFormat>(
stable_ivalue, extension_build_version));
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return
// c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with the torch::stable::detail::_to<T> function in
// torch/csrc/stable/library.h
if (stable_ivalue ==
torch::stable::detail::_from(std::nullopt, extension_build_version)) {
return c10::IValue();
}
auto sivp = torch::stable::detail::_to<StableIValue*>(
stable_ivalue, extension_build_version);
auto ival = to_ivalue(inner_type, *sivp, extension_build_version);
delete sivp;
return ival;
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from StableIValue to IValue for schema type: ",
type->str());
}
}
}
class StableIValueBoxedKernel : public c10::OperatorKernel {
public:
StableIValueBoxedKernel(
void (*fn)(StableIValue*, uint64_t, uint64_t),
uint64_t extension_build_version)
: fn_(fn), extension_build_version_(extension_build_version) {}
void operator()(
const c10::OperatorHandle& op,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
auto ministack =
std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
for (const auto idx : c10::irange(num_arguments)) {
const auto ministack_idx = num_arguments - idx - 1;
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
ministack[ministack_idx] = from_ivalue(
arg_type, torch::jit::pop(stack), extension_build_version_);
}
// boxed function is going to take a stack of StableIValues, cast them to
// our schema values, and run the function and modify the StableIValue stack
fn_(ministack.get(), num_arguments, num_returns);
// read the output from the end of the stack and wrap that back into
// IValue from StableIValue
for (size_t idx = 0; idx < num_returns; idx++) {
const c10::TypePtr& ret_type = schema.returns()[idx].type();
torch::jit::push(
stack, to_ivalue(ret_type, ministack[idx], extension_build_version_));
}
}
private:
void (*fn_)(StableIValue*, uint64_t, uint64_t);
uint64_t extension_build_version_;
};
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
TorchLibraryHandle self,
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
reinterpret_cast<torch::Library*>(self)->impl(
name,
torch::CppFunction::makeFromBoxedFunctor(
std::make_unique<StableIValueBoxedKernel>(fn, TORCH_ABI_VERSION)));
});
}
// Version-aware variant of aoti_torch_library_impl that takes an
// extension_build_version parameter for backward compatibility
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
TorchLibraryHandle self,
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t),
uint64_t extension_build_version) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
reinterpret_cast<torch::Library*>(self)->impl(
name,
torch::CppFunction::makeFromBoxedFunctor(
std::make_unique<StableIValueBoxedKernel>(
fn, extension_build_version)));
});
}
AOTITorchError aoti_torch_call_dispatcher(
const char* opName,
const char* overloadName,
StableIValue* stack) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
const auto op =
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
torch::jit::Stack ivalue_stack;
// we will only need max(num_args, num_returns)
ivalue_stack.reserve(std::max(num_arguments, num_returns));
// convert StableIValue stack to c10::IValue stack
for (const auto idx : c10::irange(num_arguments)) {
auto stable_ivalue = stack[idx];
auto arg_type = schema.arguments()[idx].type();
torch::jit::push(
ivalue_stack, to_ivalue(arg_type, stable_ivalue, TORCH_ABI_VERSION));
}
op.callBoxed(ivalue_stack);
// there should then be num_returns IValues on the stack, which
// we will convert to StableIValue and repopulate user input stack
for (const auto idx : c10::irange(num_returns)) {
const auto stack_idx = num_returns - idx - 1;
const c10::TypePtr& ret_type = schema.returns()[idx].type();
stack[stack_idx] = from_ivalue(
ret_type, torch::jit::pop(ivalue_stack), TORCH_ABI_VERSION);
}
});
}
// Schema Adapter Infrastructure
// SchemaAdapterRegistry contains the adapters registered via
// register_schema_adapter that define how to convert the StableIValue argument
// stack to an IValue stack when changes are made to the schema of an ATen
// function. This should only be relevant in the context of calling
// torch_call_dispatcher.
// Currently this only adapts the argument stack.
// C++ default argument resolution will happen at compile time in the
// torch/csrc/stable/ops.h header, so extensions always pass complete argument
// lists for the schema of the version they build against. As such, this is
// only needed if a new argument is added to the schema.
//
// This is not declared in the stable shim.h,
// so we **do not make any guarantees that the signature of this will not
// change**. If there is a need to define similar infrastructure for the returns
// of an aten function we can update this.
namespace {
using SchemaAdapterFn = std::function<torch::jit::Stack(
const c10::FunctionSchema& current_schema,
const StableIValue* extension_stack,
uint64_t extension_build_version)>;
// Global registry for schema adapters
class SchemaAdapterRegistry {
private:
std::unordered_map<
std::string,
std::vector<std::pair<uint64_t, SchemaAdapterFn>>>
adapters_;
public:
static SchemaAdapterRegistry& instance() {
static SchemaAdapterRegistry registry;
return registry;
}
void register_adapter(
const std::string& op_name,
uint64_t
applies_to_versions_below, // versions below this need the adapter
SchemaAdapterFn adapter) {
adapters_[op_name].emplace_back(applies_to_versions_below, adapter);
// Sort by version ascending - this allows us to find the first (most
// specific) match
std::sort(
adapters_[op_name].begin(),
adapters_[op_name].end(),
[](const auto& a, const auto& b) { return a.first < b.first; });
}
std::optional<SchemaAdapterFn> get_adapter(
const std::string& op_name,
uint64_t extension_version) {
auto it = adapters_.find(op_name);
if (it == adapters_.end())
return std::nullopt;
// Find the first adapter that applies (most specific due to ascending sort)
for (const auto& [applies_to_versions_below, adapter] : it->second) {
if (extension_version < applies_to_versions_below) {
return adapter;
}
}
return std::nullopt;
}
};
// Internal API for registering adapters that define how to convert the
// StableIValue **argument** stack to an IValue stack when changes are
// made to the schema of a function. adapter_fn will be used if
// extension_build_version < applies_to_versions_below.
[[maybe_unused]] AOTITorchError register_schema_adapter(
const char* op_name,
uint64_t applies_to_versions_below,
SchemaAdapterFn adapter_fn) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
auto& registry = SchemaAdapterRegistry::instance();
registry.register_adapter(
std::string(op_name), applies_to_versions_below, std::move(adapter_fn));
});
}
} // namespace
// Function to register test schema adapters for _test_schema_upgrader
// This demonstrates the adapter registration pattern (internal use only)
static AOTITorchError _register_adapters() {
// ** Schema adapters should be registered here**
// Refer to https://github.com/pytorch/pytorch/pull/165284/ for an example.
//
// if (auto err = register_schema_adapter(
// "aten::your_op",
// VERSION_FOO, // applies to versions < VERSION_FOO
// adapt_v1_to_vfoo)) {
// return err;
// }
return AOTI_TORCH_SUCCESS;
}
// Static initialization to automatically register test adapters
static struct AdapterInitializer {
AdapterInitializer() {
// Register the test adapters when the library loads
_register_adapters();
}
} adapter_initializer;
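
A hypothetical adapter sketch (the op name, VERSION_FOO, and the appended default are all invented for illustration): it converts the arguments an older extension did pass using the file's own to_ivalue, then supplies a default IValue for an argument added after VERSION_FOO.

```cpp
// Hypothetical sketch, not from the PR: adapter for an op that gained one
// trailing optional argument. Assumes the schema has at least one argument.
static torch::jit::Stack adapt_my_op_pre_foo(
    const c10::FunctionSchema& current_schema,
    const StableIValue* extension_stack,
    uint64_t extension_build_version) {
  torch::jit::Stack stack;
  const auto num_arguments = current_schema.arguments().size();
  // The older extension passed num_arguments - 1 StableIValues.
  for (size_t idx = 0; idx < num_arguments - 1; idx++) {
    const c10::TypePtr& arg_type = current_schema.arguments()[idx].type();
    torch::jit::push(
        stack, to_ivalue(arg_type, extension_stack[idx], extension_build_version));
  }
  torch::jit::push(stack, c10::IValue()); // None for the newly added argument
  return stack;
}
```

Registration would then go through _register_adapters above, e.g. register_schema_adapter("aten::my_op", VERSION_FOO, adapt_my_op_pre_foo).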
AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
const char* opName,
const char* overloadName,
StableIValue* stack,
// version of stable headers used to build the extension: necessary for
// applying schema adapters
uint64_t extension_build_version) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
const auto op =
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
torch::jit::Stack ivalue_stack;
auto& registry = SchemaAdapterRegistry::instance();
// Check if we need an adapter for this operation
if (auto adapter = registry.get_adapter(opName, extension_build_version)) {
// Use adapter to create IValue stack
ivalue_stack = (*adapter)(schema, stack, extension_build_version);
} else {
// No adapter needed - implementation matches aoti_torch_call_dispatcher
ivalue_stack.reserve(std::max(num_arguments, num_returns));
for (const auto idx : c10::irange(num_arguments)) {
auto stable_ivalue = stack[idx];
auto arg_type = schema.arguments()[idx].type();
torch::jit::push(
ivalue_stack,
to_ivalue(arg_type, stable_ivalue, extension_build_version));
}
}
op.callBoxed(ivalue_stack);
// there should then be num_returns IValues on the stack, which
// we will convert to StableIValue and repopulate user input stack
for (const auto idx : c10::irange(num_returns)) {
const auto stack_idx = num_returns - idx - 1;
const c10::TypePtr& ret_type = schema.returns()[idx].type();
stack[stack_idx] = from_ivalue(
ret_type, torch::jit::pop(ivalue_stack), extension_build_version);
}
});
}

View File

@ -0,0 +1,46 @@
#ifndef STABLE_TORCH_SHIM
#define STABLE_TORCH_SHIM
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/version.h>
// This header defines stable C API extensions for backward/forward
// compatibility when calling ATen operations through the dispatcher.
//
// This is separate from the main AOTI shim to provide versioning capabilities
// for schema changes in native ATen functions.
#ifdef __cplusplus
extern "C" {
#endif
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
using StableIValue = uint64_t;
// Has the same semantic as aoti_torch_call_dispatcher, but takes an
// additional argument for the extension build version. This is
// needed for backward compatibility when calling native functions via
// the dispatcher. The caller should pass in the libtorch version the
// extension is building with (NOT target version).
AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
const char* opName,
const char* overloadName,
StableIValue* stack,
uint64_t extension_build_version);
// Version-aware variant of aoti_torch_library_impl that takes an
// extension_build_version parameter for backward compatibility
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
TorchLibraryHandle self,
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t),
uint64_t extension_build_version);
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
#ifdef __cplusplus
} // extern "C"
#endif
#endif // STABLE_TORCH_SHIM

View File

@ -4,12 +4,14 @@
// code for better UX.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/headeronly/macros/Macros.h>
// Technically, this file doesn't use anything from stableivalue_conversions.h,
// but we need to include it here as the contents of stableivalue_conversions.h
// used to live here and so we need to expose them for backwards compatibility.
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/version.h>
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
@ -81,7 +83,11 @@ class StableLibrary final {
StableLibrary& impl(
const char* name,
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
torch_library_impl(lib_, name, fn, TORCH_ABI_VERSION);
#else
aoti_torch_library_impl(lib_, name, fn);
#endif
return *this;
}

View File

@ -8,6 +8,8 @@
#include <vector>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/version.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
@ -25,8 +27,13 @@ inline torch::stable::Tensor empty_like(const torch::stable::Tensor& self) {
torch::stable::detail::from(std::nullopt),
torch::stable::detail::from(std::nullopt),
torch::stable::detail::from(std::nullopt)};
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::empty_like", "", stack.data(), TORCH_ABI_VERSION));
#else
TORCH_ERROR_CODE_CHECK(
aoti_torch_call_dispatcher("aten::empty_like", "", stack.data()));
#endif
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
@ -201,8 +208,13 @@ inline torch::stable::Tensor transpose(
torch::stable::detail::from(self),
torch::stable::detail::from(dim0),
torch::stable::detail::from(dim1)};
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::transpose", "int", stack.data(), TORCH_ABI_VERSION));
#else
TORCH_ERROR_CODE_CHECK(
aoti_torch_call_dispatcher("aten::transpose", "int", stack.data()));
#endif
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
@ -212,8 +224,13 @@ inline torch::stable::Tensor transpose(
inline torch::stable::Tensor zero_(torch::stable::Tensor& self) {
const auto num_args = 1;
std::array<StableIValue, num_args> stack{torch::stable::detail::from(self)};
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::zero_", "", stack.data(), TORCH_ABI_VERSION));
#else
TORCH_ERROR_CODE_CHECK(
aoti_torch_call_dispatcher("aten::zero_", "", stack.data()));
#endif
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
@ -228,8 +245,13 @@ inline torch::stable::Tensor copy_(
torch::stable::detail::from(self),
torch::stable::detail::from(src),
torch::stable::detail::from(non_blocking.value_or(false))};
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::copy_", "", stack.data(), TORCH_ABI_VERSION));
#else
TORCH_ERROR_CODE_CHECK(
aoti_torch_call_dispatcher("aten::copy_", "", stack.data()));
#endif
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
@ -240,9 +262,20 @@ inline torch::stable::Tensor clone(const torch::stable::Tensor& self) {
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self),
torch::stable::detail::from(std::nullopt)};
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::clone", "", stack.data(), TORCH_ABI_VERSION));
#else
TORCH_ERROR_CODE_CHECK(
aoti_torch_call_dispatcher("aten::clone", "", stack.data()));
#endif
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
// New ops should be added here if they use a brand new shim API
#endif
HIDDEN_NAMESPACE_END(torch, stable)
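
A hedged extension-side sketch of how these version-gated wrappers are consumed (the helper name is invented and the include is assumed to bring in torch::stable::Tensor, as the signatures above suggest): the extension just calls the op and the matching dispatcher shim is selected at compile time.

```cpp
// Hypothetical extension code, not part of the diff.
#include <torch/csrc/stable/ops.h>

// Returns a zero-filled tensor with the same metadata as `t`,
// built only from the stable wrappers above (empty_like + zero_).
torch::stable::Tensor zeros_like_stable(const torch::stable::Tensor& t) {
  torch::stable::Tensor out = torch::stable::empty_like(t);
  return torch::stable::zero_(out);
}
```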

View File

@ -24,12 +24,17 @@ T to(StableIValue val);
// =============================================================================
// =============================================================================
// FROM CONVERSIONS (T -> StableIValue)
// =============================================================================
// ======================================================================
// Specialization for general copyable types (catch-all) => StableIValue
template <typename T>
struct FromImpl {
static StableIValue call(T val) {
static StableIValue call(
T val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
@ -68,7 +73,12 @@ struct FromImpl {
using torch::headeronly::ScalarType;
template <>
struct FromImpl<ScalarType> {
static StableIValue call(ScalarType val) {
static StableIValue call(
ScalarType val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
switch (val) {
case ScalarType::Byte:
return from(aoti_torch_dtype_uint8());
@ -121,7 +131,12 @@ struct FromImpl<ScalarType> {
// Specialization for std::nullopt_t => StableIValue
template <>
struct FromImpl<std::nullopt_t> {
static StableIValue call(std::nullopt_t val) {
static StableIValue call(
std::nullopt_t val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
return from(nullptr);
}
};
@ -157,11 +172,15 @@ struct FromImpl<std::nullopt_t> {
// std::optional<T> or a std::nullopt.
template <typename T>
struct FromImpl<std::optional<T>> {
static StableIValue call(const std::optional<T>& val) {
static StableIValue call(
const std::optional<T>& val,
uint64_t extension_build_version,
bool is_internal) {
if (!val.has_value()) {
return from(std::nullopt);
}
return from(new StableIValue(from(val.value())));
return from(new StableIValue(detail::FromImpl<T>::call(
val.value(), extension_build_version, is_internal)));
}
};
@ -169,7 +188,12 @@ struct FromImpl<std::optional<T>> {
// Returns a new owning reference of the underlying Tensor.
template <>
struct FromImpl<torch::stable::Tensor> {
static StableIValue call(const torch::stable::Tensor& val) {
static StableIValue call(
const torch::stable::Tensor& val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
AtenTensorHandle new_ath;
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
return from(new_ath);
@ -183,7 +207,12 @@ struct FromImpl<torch::stable::Tensor> {
// Specialization for StableIValue => general copyable types (catch-all)
template <typename T>
struct ToImpl {
static T call(StableIValue val) {
static T call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
static_assert(std::is_trivially_copyable_v<T>);
// T may not have a default constructor. (For example, it might be
// c10::Device.) However, std::memcpy implicitly creates a T at the
@ -218,7 +247,12 @@ struct ToImpl {
// Specialization for StableIValue => torch::headeronly::ScalarType
template <>
struct ToImpl<ScalarType> {
static ScalarType call(StableIValue val) {
static ScalarType call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
int32_t shim_scalartype = to<int32_t>(val);
if (shim_scalartype == aoti_torch_dtype_uint8()) {
return ScalarType::Byte;
@ -273,7 +307,12 @@ struct ToImpl<ScalarType> {
// Specialization for StableIValue => std::nullopt_t
template <>
struct ToImpl<std::nullopt_t> {
static std::nullopt_t call(StableIValue val) {
static std::nullopt_t call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
// val should be equivalent to from(nullptr)
return std::nullopt;
}
@ -284,14 +323,18 @@ struct ToImpl<std::nullopt_t> {
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
template <typename T>
struct ToImpl<std::optional<T>> {
static std::optional<T> call(StableIValue val) {
static std::optional<T> call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
auto sivp = to<StableIValue*>(val);
// sivp is either nullptr or a pointer to a StableIValue
if (sivp == nullptr) {
return {};
}
auto inner_val = to<T>(*sivp);
auto inner_val =
detail::ToImpl<T>::call(*sivp, extension_build_version, is_internal);
// free the memory associated with StableIValue* sivp
delete sivp;
@ -305,7 +348,12 @@ struct ToImpl<std::optional<T>> {
// underlying AtenTensorHandle.
template <>
struct ToImpl<torch::stable::Tensor> {
static torch::stable::Tensor call(StableIValue val) {
static torch::stable::Tensor call(
StableIValue val,
uint64_t extension_build_version,
bool is_internal) {
(void)extension_build_version; // Unused parameter
(void)is_internal; // Unused parameter
return torch::stable::Tensor(to<AtenTensorHandle>(val));
}
};
@ -315,25 +363,60 @@ struct ToImpl<torch::stable::Tensor> {
// =============================================================================
// Expose the partially templated class functions through single functions
// The non-private versions will be used by the extension or headers that
// the extension includes.
template <typename T>
inline StableIValue from(T val) {
return detail::FromImpl<T>::call(val);
return detail::FromImpl<T>::call(
val, aoti_torch_abi_version(), /*is_internal=*/false);
}
template <typename T>
inline StableIValue from(const std::optional<T>& val) {
return detail::FromImpl<std::optional<T>>::call(val);
return detail::FromImpl<std::optional<T>>::call(
val, aoti_torch_abi_version(), /*is_internal=*/false);
}
// The below overload is used! See https://godbolt.org/z/859cshxrW
// We are suppressing the warning for versions clang12- and gcc11-
[[maybe_unused]] inline StableIValue from(const torch::stable::Tensor& val) {
return detail::FromImpl<torch::stable::Tensor>::call(val);
return detail::FromImpl<torch::stable::Tensor>::call(
val, aoti_torch_abi_version(), /*is_internal=*/false);
}
template <typename T>
inline T to(StableIValue val) {
return detail::ToImpl<T>::call(val);
return detail::ToImpl<T>::call(
val, aoti_torch_abi_version(), /*is_internal=*/false);
}
// Internal conversion functions used by from_ivalue and to_ivalue.
// These are used in libtorch
template <typename T>
inline StableIValue _from(T val, uint64_t extension_build_version) {
return detail::FromImpl<T>::call(
val, extension_build_version, /*is_internal=*/true);
}
template <typename T>
inline StableIValue _from(
const std::optional<T>& val,
uint64_t extension_build_version) {
return detail::FromImpl<std::optional<T>>::call(
val, extension_build_version, /*is_internal=*/true);
}
[[maybe_unused]] inline StableIValue _from(
const torch::stable::Tensor& val,
uint64_t extension_build_version) {
return detail::FromImpl<torch::stable::Tensor>::call(
val, extension_build_version, /*is_internal=*/true);
}
template <typename T>
inline T _to(StableIValue val, uint64_t extension_build_version) {
return detail::ToImpl<T>::call(
val, extension_build_version, /*is_internal=*/true);
}
HIDDEN_NAMESPACE_END(torch, stable, detail)
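
A minimal, hedged illustration of the public entry points above (extension-side code, not part of the diff): from/to still take only the value, with the current ABI version and is_internal=false filled in by the wrappers, and the optional specializations manage the heap-allocated inner StableIValue between them.

```cpp
// Hypothetical round trip through the wrappers above.
#include <cstdint>
#include <optional>
#include <torch/csrc/stable/library.h>

std::optional<int64_t> roundtrip(std::optional<int64_t> v) {
  StableIValue boxed = torch::stable::detail::from(v);  // may heap-allocate for a set optional
  return torch::stable::detail::to<std::optional<int64_t>>(boxed);  // frees that allocation
}
```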

View File

@ -0,0 +1,29 @@
#pragma once
#include <torch/headeronly/version.h>
// Stable ABI Version Targeting
//
// This header provides version targeting capabilities for the PyTorch Stable
// ABI. Users can define TORCH_TARGET_VERSION to target a specific stable ABI
// version instead of using the current TORCH_ABI_VERSION of libtorch at
// compile time.
//
// Usage:
// Default behavior (uses current ABI version):
// #include <torch/csrc/stable/library.h>
//
// Target a specific stable version (major.minor) (e.g. PyTorch 2.9):
// (1) Pass a compiler flag -DTORCH_TARGET_VERSION=0x0209000000000000
// (2) Alternatively, define TORCH_TARGET_VERSION in the source code before
// including any header files:
// #define TORCH_TARGET_VERSION (((0ULL + 2) << 56) | ((0ULL + 9) << 48))
// #include <torch/csrc/stable/library.h>
#ifdef TORCH_TARGET_VERSION
#define TORCH_FEATURE_VERSION TORCH_TARGET_VERSION
#else
#define TORCH_FEATURE_VERSION TORCH_ABI_VERSION
#endif
#define TORCH_VERSION_2_10_0 (((0ULL + 2) << 56) | ((0ULL + 10) << 48))

View File

@ -671,8 +671,6 @@ class DTensor(torch.Tensor):
def __metadata_guard__(
cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool]
) -> bool:
# TODO - delete this - This is now unused after the PR -
# https://github.com/pytorch/pytorch/pull/165824
orig_spec, orig_requires_grad = orig
other_spec, other_requires_grad = other
return (

View File

@ -6612,13 +6612,13 @@ class ShapeEnv:
desc = "Could not guard on data-dependent expression"
size_oblivious_result_msg = (
"consider using data-dependent friendly APIs such as "
"guard_or_false, guard_or_true and statically_known_true"
"guard_or_false, guard_or_true and statically_known_true."
)
msg = (
f"{desc} {expr} (unhinted: {unhinted_expr}). "
f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n"
f"{size_oblivious_result_msg}"
f"{size_oblivious_result_msg}\n"
f"Caused by: {sloc}\n"
'For more information, run with TORCH_LOGS="dynamic"\n'
"For extended logs when we create symbols, also add "

View File

@ -19,8 +19,8 @@
/// Indicates the ABI version of LibTorch as a single uint64.
/// [ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ]
/// [ MAJ ][ MIN ][ PATCH][ ABI TAG ]
#define TORCH_ABI_VERSION \
(uint64_t)TORCH_VERSION_MAJOR << 56 | \
(uint64_t)TORCH_VERSION_MINOR << 48 | \
(uint64_t)TORCH_VERSION_PATCH << 40 | \
TORCH_VERSION_ABI_TAG << 0
#define TORCH_ABI_VERSION ( \
((0ULL + TORCH_VERSION_MAJOR) << 56) | \
((0ULL + TORCH_VERSION_MINOR) << 48) | \
((0ULL + TORCH_VERSION_PATCH) << 40) | \
((0ULL + TORCH_VERSION_ABI_TAG) << 0))
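
As a quick sanity check of the byte layout documented above, the fields can be unpacked back out of a packed value (a standalone sketch, not code from the PR; the constant mirrors TORCH_VERSION_2_10_0 from torch/csrc/stable/version.h):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const uint64_t v = ((0ULL + 2) << 56) | ((0ULL + 10) << 48);
  std::cout << ((v >> 56) & 0xFF) << "."    // MAJ   -> 2
            << ((v >> 48) & 0xFF) << "."    // MIN   -> 10
            << ((v >> 40) & 0xFF) << "\n";  // PATCH -> 0
  return 0;
}
```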

View File

@ -14,11 +14,14 @@ from torch.backends.cuda import (
SDPAParams,
)
from .varlen import varlen_attn
__all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"varlen_attn",
]
# Note: [SDPA warnings]

View File

@ -7,7 +7,7 @@ that calls into the optimized Flash Attention kernels.
import logging
from functools import lru_cache
from typing import Any, NamedTuple, Optional, Union
from typing import NamedTuple, Optional, Union
import torch
@ -33,7 +33,8 @@ class AuxRequest(NamedTuple):
lse: bool = False
@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
# import failures when I try to register as custom op
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
def _varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
@ -43,7 +44,7 @@ def _varlen_attn(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Private custom op for variable-length attention.
@ -69,7 +70,7 @@ def _varlen_attn(
False, # return_debug_mask
)
# cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
output, softmax_lse, rng_state = result[0], result[1], result[6]
output, softmax_lse = result[0], result[1]
else:
log.info("Using Flash Attention backend for varlen_attn")
output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
@ -85,13 +86,10 @@ def _varlen_attn(
return_debug_mask=False,
)
rng_state_ = torch.zeros(
(2,), dtype=torch.uint64, device=query.device
) # hardcoded since dropout is hardcoded to 0
return output, softmax_lse, rng_state_
return output, softmax_lse
@_varlen_attn.register_fake
# @_varlen_attn.register_fake
def _varlen_attn_fake(
query: torch.Tensor,
key: torch.Tensor,
@ -101,7 +99,7 @@ def _varlen_attn_fake(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
@ -119,9 +117,7 @@ def _varlen_attn_fake(
(num_heads, total_q), dtype=torch.float, device=query.device
)
rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)
return output, logsumexp, rng_state
return output, logsumexp
def varlen_attn(
@ -195,145 +191,9 @@ def varlen_attn(
... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
... )
"""
out, lse, _ = torch.ops.torch_attn._varlen_attn(
out, lse = _varlen_attn(
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
)
if return_aux is not None and return_aux.lse:
return out, lse
return out
def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs
out, lse, rng_state = output
ctx.query = query
ctx.key = key
ctx.value = value
ctx.cu_seq_q = cu_seq_q
ctx.cu_seq_k = cu_seq_k
ctx.max_q = max_q
ctx.max_k = max_k
ctx.is_causal = is_causal
ctx.output = out
ctx.lse = lse
ctx.rng_state = rng_state
@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
def _varlen_attn_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
unused = torch.empty(0, device=query.device)
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
log.info("Using cuDNN backend for varlen_attn")
dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
else:
log.info("Using Flash Attention backend for varlen_attn")
dq, dk, dv = torch.ops.aten._flash_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
return dq, dk, dv
@_varlen_attn_backward.register_fake
def _varlen_attn_backward_fake(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
"""
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
return grad_query, grad_key, grad_value
def _backward(
ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
query = ctx.query
key = ctx.key
value = ctx.value
cu_seq_q = ctx.cu_seq_q
cu_seq_k = ctx.cu_seq_k
max_q = ctx.max_q
max_k = ctx.max_k
is_causal = ctx.is_causal
out = ctx.output
lse = ctx.lse
rng_state = ctx.rng_state
# rng_state = torch.empty(2, device=query.device)
dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
is_causal,
rng_state,
)
return dq, dk, dv, None, None, None, None, None, None
_varlen_attn.register_autograd(_backward, setup_context=_setup_context)

View File

@ -265,10 +265,8 @@ def _private_register_pytree_node(
)
def _is_pytreespec_instance(
obj: Any, /
) -> TypeIs[Union[TreeSpec, python_pytree.TreeSpec]]:
return isinstance(obj, (TreeSpec, python_pytree.TreeSpec))
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)
def treespec_leaf() -> TreeSpec:
@ -974,7 +972,7 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
"""Serialize a treespec to a JSON string."""
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be instance of "
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)

View File

@ -20,7 +20,6 @@ import functools
import importlib
import importlib.metadata
import json
import sys
import threading
import types
import warnings
@ -37,11 +36,10 @@ from typing import (
Optional,
overload,
Protocol,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import deprecated, NamedTuple, Self, TypeIs
from typing_extensions import deprecated, NamedTuple, Self
from torch.torch_version import TorchVersion as _TorchVersion
@ -1338,39 +1336,6 @@ def treespec_dict(
return TreeSpec(dict, list(dct.keys()), list(dct.values()))
if TYPE_CHECKING:
import torch.utils._cxx_pytree as cxx
def _is_pytreespec_instance(obj: Any) -> TypeIs[Union[TreeSpec, "cxx.TreeSpec"]]:
if isinstance(obj, TreeSpec):
return True
if "torch.utils._cxx_pytree" in sys.modules:
# The C++ pytree module is not always available, so we check if it is loaded.
# If the C++ pytree module is loaded, we can check if the treespec
# is an instance of the C++ TreeSpec class.
from torch.utils._cxx_pytree import TreeSpec as CxxTreeSpec
if isinstance(obj, CxxTreeSpec):
return True
return False
def _ensure_python_treespec_instance(
treespec: Union[TreeSpec, "cxx.TreeSpec"],
) -> TreeSpec:
if isinstance(treespec, TreeSpec):
return treespec
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
dummy_tree = treespec.unflatten([0] * treespec.num_leaves)
return tree_structure(dummy_tree)
def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1401,10 +1366,10 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
"""Given a list of values and a TreeSpec, builds a pytree.
This is the inverse operation of `tree_flatten`.
"""
if not _is_pytreespec_instance(treespec):
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
f"instance of TreeSpec but got item of type {type(treespec)}.",
)
return treespec.unflatten(leaves)
@ -1835,30 +1800,34 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
def broadcast_prefix(
prefix_tree: PyTree,
full_tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> list[Any]:
result: list[Any] = []
if not isinstance(treespec, TreeSpec):
raise AssertionError("treespec must be a TreeSpec")
def add_leaves(x: Any, subtree: PyTree) -> None:
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
result.extend([x] * subtreespec.num_leaves)
tree_map_(
add_leaves,
prefix_tree,
full_tree,
is_leaf=is_leaf,
)
return result
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
except ValueError:
if tree_is_leaf(tree, is_leaf=is_leaf):
return [tree] * treespec.num_leaves
if treespec.is_leaf():
return None
node_type = _get_node_type(tree)
if node_type != treespec.type:
return None
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)
# Check if the Node is different from the spec
if len(child_pytrees) != treespec.num_children or context != treespec._context:
return None
# Recursively flatten the children
result: list[Any] = []
for child, child_spec in zip(child_pytrees, treespec._children):
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
if flat is not None:
result += flat
else:
return None
return result
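
A small hedged usage sketch of the prefix-broadcasting behavior that both the old recursive version and the new broadcast_prefix-based version implement (private API; the outputs in comments are what the semantics above imply):

```python
import torch.utils._pytree as pytree

spec = pytree.tree_structure({"a": 0, "b": (0, 0)})

# A leaf prefix is repeated once per leaf of the full structure.
print(pytree._broadcast_to_and_flatten(1, spec))                   # [1, 1, 1]

# A prefix matching the top-level dict broadcasts each value over its subtree.
print(pytree._broadcast_to_and_flatten({"a": 10, "b": 20}, spec))  # [10, 20, 20]
```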
@dataclasses.dataclass
@ -1972,7 +1941,11 @@ _SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
treespec = _ensure_python_treespec_instance(treespec)
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}.",
)
if protocol is None:
protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL