Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-02 23:15:01 +08:00

Compare commits: ciflow/tru ... pr166149 (16 commits)

| SHA1 |
|---|
| 4f6a767b3c |
| a3fe1825aa |
| deb776319b |
| d7040e6d75 |
| 35f3572fa4 |
| bc5111cd8d |
| 398fdd32bb |
| 5fd1d41e62 |
| c594950e86 |
| 14102fb1f3 |
| 5cdbcb5233 |
| eae701cad0 |
| 8f51556daa |
| c0bbda37e8 |
| fefb546b91 |
| d6d6fa26f5 |
@@ -50,35 +50,18 @@ static inline bool parseLinearFlatten3d() {
// `_flatten_nd_linear` flattens all but the last dimension of the input tensor
// before passing it to linear operation
static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
const auto input_sizes = input.sym_sizes();

const auto result_flattened = [&]() -> Tensor {
const auto input_ncols = input_sizes.back();
const auto input_flattened_nrows = [&]() -> c10::SymInt {
// can't use -1 in reshape because it errors when a dimension is 0
auto flattened_nrows = c10::SymInt{1};
for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) {
flattened_nrows *= size;
}
return flattened_nrows;
}();

const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols});
if (weight.layout() == c10::kStrided) {
return at::addmm(bias, input_flattened, weight.t());
} else {
// weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
// so we transpose the problem.
// NOTE: at::matmul handles (dense @ sparse) similarly.
const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1);
return at::addmm(bias_t, weight, input_flattened.t()).t();
const auto input_sizes = input.sym_sizes();
// can't use -1 in reshape because it errors when a dimension is 0
c10::SymInt flattened_dim = 1;
for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
flattened_dim = flattened_dim * input_sizes[i];
}
}();

// Unflatten flattened row dims
auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()};
result_sizes.back() = result_flattened.sym_size(1);
return result_flattened.view_symint(result_sizes);
auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)});
const auto result = at::addmm(bias, inp_reshape, weight.t());
auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
sizes_vec.push_back(result.sym_size(1));
return result.view_symint(sizes_vec);
}


@@ -107,23 +90,15 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
// Fused op is marginally faster.
return at::addmm(*bias, input, weight.t());
}

const auto is_bias_likely_fusable = (
bias->defined() &&
// cuBLASLt: will fuse in the epilogue without copies
// when input/weight/bias are all strided.
// When weight is not strided, bias will not be fused,
// but we can still dispatch here to avoid at::matmul
// path which will probably use a very similar
// flattening optimization.
(bias->dim() == 1 && bias->is_contiguous_or_false())
);
if (is_bias_likely_fusable && !input.is_xla()) {
// Also hit the fused path for contiguous nD input, if not using xla
if (bias->defined() && !input.is_xla()) {
// Also hit the fused path for contiguous 3D input, if not using xla
// backend. Reshaping/flattening has some performance implications on xla.
if (input.is_contiguous_or_false()) {
bool is_contiguous = input.is_contiguous_or_false();
if (is_contiguous && input_dim == 3) {
return _flatten_nd_linear(input, weight, *bias);
} else if (parseLinearFlatten3d()) {
} else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
return _flatten_nd_linear(input, weight, *bias);
} else if (parseLinearFlatten3d() && input_dim == 3) {
// If user forces flattening via env var
const Tensor input_cont = input.contiguous();
return _flatten_nd_linear(input_cont, weight, *bias);

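The flattening trick in `_flatten_nd_linear` above can be summarized with a short Python-level sketch (an illustration only, not the ATen code): collapsing every leading dimension lets a single `addmm` serve arbitrary nD inputs, after which the leading shape is restored on the result.

```python
import torch

def flatten_nd_linear_reference(x, weight, bias):
    # Collapse every dimension except the last one, run one addmm,
    # then restore the leading dimensions on the result.
    lead_shape = x.shape[:-1]
    x2d = x.reshape(-1, x.shape[-1])            # [prod(lead), in_features]
    out2d = torch.addmm(bias, x2d, weight.t())  # [prod(lead), out_features]
    return out2d.view(*lead_shape, weight.shape[0])

x = torch.randn(4, 7, 16)
w = torch.randn(32, 16)
b = torch.randn(32)
assert torch.allclose(flatten_nd_linear_reference(x, w, b),
                      torch.nn.functional.linear(x, w, b), atol=1e-5)
```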
@@ -22,9 +22,6 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#ifdef USE_ROCM
#include <ATen/native/hip/ck_group_gemm.h>
#endif
#include <ATen/ceil_div.h>

#ifdef USE_FBGEMM_GENAI
@@ -639,19 +636,12 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
// fast path, no d2h sync needed
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}

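The fallback path referenced above is not shown in this hunk. As a mental model only (an assumption, not the actual `_grouped_mm_fallback` implementation), a grouped GEMM with a 2D LHS and a 3D RHS amounts to slicing the LHS rows at the cumulative offsets and running one ordinary matmul per group:

```python
import torch

def grouped_mm_fallback_2d3d(mat_a, mat_b, offs):
    # mat_a: [total_m, k]; mat_b: [n_groups, k, n]; offs: cumulative row offsets into A.
    outs, start = [], 0
    for g in range(mat_b.shape[0]):
        end = int(offs[g])
        outs.append(mat_a[start:end] @ mat_b[g])  # ordinary mm per group
        start = end
    return torch.cat(outs, dim=0)                 # [total_m, n]

offs = torch.tensor([3, 5, 9], dtype=torch.int32)
a = torch.randn(9, 16, dtype=torch.bfloat16)
b = torch.randn(3, 16, 8, dtype=torch.bfloat16)
print(grouped_mm_fallback_2d3d(a, b, offs).shape)  # torch.Size([9, 8])
```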
@@ -1,19 +0,0 @@
#pragma once

#include <ATen/Tensor.h>
#include <c10/core/ScalarType.h>
#include <optional>

namespace at {
namespace hip {
namespace detail {
void group_gemm_ck(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
at::Tensor& out);

} // namespace detail
} // namespace hip
} // namespace at
@@ -1,458 +0,0 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>

#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

namespace at {
namespace hip {
namespace detail {

namespace CkTypes {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}

template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
3, 8, 8, 1,
1, 1,
S<1,32,1,8>, 4
>;

template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
const at::Tensor& mat_a,
const at::Tensor& mat_b,
const std::optional<at::Tensor>& offs,
at::Tensor& out)
{
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
using PassThrough = CkTypes::PassThrough;

std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a_ptrs, p_b_ptrs;
std::vector<void*> p_e_ptrs;
// Note: d_ptrs will be resized after we populate the other vectors

const int mat_a_dim = mat_a.dim();
const int mat_b_dim = mat_b.dim();

const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
const size_t a_element_size = mat_a.element_size();
const size_t b_element_size = mat_b.element_size();
const size_t out_element_size = out.element_size();

// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
if (mat_a_dim == 2 && mat_b_dim == 2) {
// 2D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
const int M = mat_a.size(0); // number of rows in A
const int N = mat_b.size(1); // number of columns in B
const int K = mat_a.size(1); // columns in A == rows in B
// for 2d*2d input, output is 3d.
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
for (int i = 0; i < num_groups; ++i) {
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
int end_k = offs_accessor[i];
int k = end_k - start_k;

//K dimension are sliced, hence select stride(1) always.
//K dimension is always dimension 1, regardless of memory layout (row/column major)
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
const void* group_b_ptr;
int ldb;

if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
// Leading dimension = distance between rows = stride(0)
ldb = mat_b.stride(0);
} else {
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
// Leading dimension = distance between columns = stride(1)
ldb = mat_b.stride(1);
}

// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
int lda, ldc;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
lda = mat_a.stride(0);
} else {
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
lda = mat_a.stride(1);
}
// Output is always row-major in 3D tensor [num_groups, M, N]
// Leading dimension for each group's [M,N] slice = stride(1) = N
ldc = out.stride(1);
size_t output_group_bytes = M * N * out_element_size;
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;

gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
// 2D*3D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);

// 2d*3d input, output is 2d.
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
const int K = mat_a.size(1); // columns in A
// For 2D-3D case: The output determines N (result width)
const int N = out.size(1); // N is the width of the output tensor

for (int i = 0; i < num_groups; ++i) {
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
int end_m = offs_accessor[i];
int m = end_m - start_m;

// Skip zero-sized groups but continue processing subsequent groups
if (m <= 0) {
continue;
}

// Select A rows for group i: skip start_m rows
const void* group_a_ptr;
int lda;
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
lda = mat_a.stride(0); // distance between rows
} else {
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;

// Detect stride pattern for A tensor to determine appropriate lda calculation
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));

if (a_is_strided_tensor) {
// For strided A tensors: stride(0) gives the actual leading dimension
lda = mat_a.stride(0);
} else {
// For non-strided A tensors: use the M dimension (total rows)
lda = mat_a.size(0); // Total M dimension for column-major layout
}
}

// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
int ldb;

if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
} else {
// Detect stride pattern to determine appropriate ldb calculation
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));

if (is_strided_tensor) {
// For strided tensors: stride(2) gives the actual leading dimension
ldb = mat_b.stride(2);
} else {
// For non-strided tensors: use the N dimension
ldb = mat_b.size(1);
}
}

// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)

gemm_descs.push_back({
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
// 3d*3d input, output is 3d - batched matrix multiplication
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
// Each batch is processed as a separate GEMM operation
const int batch_size = mat_a.size(0);
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)

// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
int N;
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
N = mat_b.size(2);
} else if (mat_b.size(2) == K) {
// B is [batch, n, k] - transposed layout
N = mat_b.size(1);
} else {
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
}

for (int i = 0; i < batch_size; ++i) {
// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;

// Select B batch for group i: B[i, :, :]
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;

// Select output batch for group i: Output[i, :, :]
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;

int lda, ldb, ldc;

if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
} else {
// Column-major A: leading dimension = distance between columns = stride(2)
lda = mat_a.stride(2);
}

if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B: leading dimension = distance between rows
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(1); // stride between K rows
} else {
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
}
} else {
// Column-major B: leading dimension = distance between columns
if (mat_b.size(1) == K) {
// B is [batch, k, n] - normal layout
ldb = mat_b.stride(2); // stride between N columns
} else {
// B is [batch, n, k] - transposed layout
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
}
}

// Output is typically row-major: leading dimension = distance between rows = stride(1)
ldc = out.stride(1);

gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(N),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
// 3D*2D case requires offset tensor
auto offs_accessor = offs->accessor<int, 1>();
int num_groups = offs_accessor.size(0);
// 3d*2d input, output is 3d.
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
const int batch_size = mat_a.size(0); // n_groups
const int M = mat_a.size(1); // rows in each A matrix
const int K = mat_a.size(2); // columns in A

// For row-major A and B case: B should be [K, total_N]
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major

for (int i = 0; i < num_groups; ++i) {
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
int end_n = offs_accessor[i];
int n = end_n - start_n;

// Skip zero-sized groups but continue processing subsequent groups
if (n <= 0) {
continue;
}

// Select A batch for group i: A[i, :, :]
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;

// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
const void* group_b_ptr;
int ldb;

// Check if B is row-major or column-major
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
// Row-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(0); // distance between rows (should be total_N)
} else {
// Column-major B [K, total_N]: slice columns [start_n:end_n]
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
ldb = mat_b.stride(1); // distance between columns (should be K)
}

// Select output slice for group i: Output[:, start_n:end_n]
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;

int lda, ldc;

// Row-major A: leading dimension = distance between rows = stride(1)
lda = mat_a.stride(1);
// Output is row-major: leading dimension = distance between rows = stride(0)
ldc = out.stride(0);

gemm_descs.push_back({
static_cast<ck::index_t>(M),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(K),
static_cast<ck::index_t>(lda),
static_cast<ck::index_t>(ldb),
static_cast<ck::index_t>(ldc)
});
p_a_ptrs.push_back(group_a_ptr);
p_b_ptrs.push_back(group_b_ptr);
p_e_ptrs.push_back(group_e_ptr);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
}

TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");

// Initialize d_ptrs with the correct size
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());

static DeviceOp gemm_instance;
auto argument = gemm_instance.MakeArgument(
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
);
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
"CK Group GEMM: argument unsupported (shape/strides/type config)");
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);

void* gemm_arg_buf = nullptr;
void* ws_buf = nullptr;

hipMalloc(&gemm_arg_buf, arg_buf_size);
hipMalloc(&ws_buf, ws_size);

gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);

auto invoker = gemm_instance.MakeInvoker();
hipStream_t stream = c10::hip::getCurrentHIPStream();
invoker.Run(argument, {stream});
hipFree(gemm_arg_buf);
hipFree(ws_buf);
}

void group_gemm_ck(
const at::Tensor& input_a,
const at::Tensor& input_b_colmajor,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& /*bias*/,
at::Tensor& out)
{
// Detect if input_a is row-major based on stride pattern
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
// Ensure tensor A is row-major and contiguous if not already
at::Tensor mat_a = input_a;
if (!a_row_major) {
// If A is not row-major, make it contiguous (row-major)
mat_a = input_a.contiguous();
}
// Force tensor B to be column-major using double transpose trick
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
at::Tensor mat_b = input_b_colmajor;
if (!b_col_major) {
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
}

// For 3D tensors, check the last dimension stride for row-major detection
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);

if (mat_a.dtype() == at::kBFloat16) {
// bf16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kHalf) {
// fp16 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
}
} else if (mat_a.dtype() == at::kFloat) {
// fp32 path
if (a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (a_row_major && !b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else if (!a_row_major && b_row_major) {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
} else {
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
}
} else {
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
}

}

} // namespace detail
} // namespace hip
} // namespace at
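The per-case pointer arithmetic in the removed kernel above mirrors how the offsets tensor partitions one dimension of the problem. As a hedged, PyTorch-level illustration only (not part of the removed code), the 2D x 2D case slices the shared K dimension of both operands at consecutive cumulative offsets and stacks the per-group results into a 3D output:

```python
import torch

def ck_2d2d_reference(mat_a, mat_b, offs):
    # mat_a: [M, K_total], mat_b: [K_total, N], offs: cumulative K offsets.
    # The result imitates the kernel's 3D [num_groups, M, N] output layout.
    outs, start_k = [], 0
    for g in range(offs.numel()):
        end_k = int(offs[g])
        outs.append(mat_a[:, start_k:end_k] @ mat_b[start_k:end_k, :])
        start_k = end_k
    return torch.stack(outs)  # [num_groups, M, N]

offs = torch.tensor([8, 24, 32], dtype=torch.int32)
a, b = torch.randn(16, 32), torch.randn(32, 64)
print(ck_2d2d_reference(a, b, offs).shape)  # torch.Size([3, 16, 64])
```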
@@ -482,6 +482,7 @@ inductor_core_resources = [
"torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
"torch/csrc/jit/serialization/pickle.cpp",
"torch/csrc/shim_common.cpp",
]

libtorch_core_sources = sorted(

@@ -464,25 +464,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
run(g, 64, 8)
self.assertEqual(cnt.frame_count, 2)

def test_dtensor_requires_grad_recompile(self):
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

@torch.compile(backend=cnt, fullgraph=True)
def f(x):
y = x * x
return y.to_local()

full_x = torch.randn(8, 8, requires_grad=False)
x = distribute_tensor(full_x, mesh, [Shard(0)])
f(x)

full_x = torch.randn(8, 8, requires_grad=True)
x = distribute_tensor(full_x, mesh, [Shard(0)])
f(x)

self.assertEqual(cnt.frame_count, 2)

def test_dtensor_attribute_access_on_intermediate(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

@@ -234,6 +234,27 @@ class InPlaceCompilationTests(TestCase):
with self.assertRaises(IndexError):
fn(torch.randn(10), 99)

def test_list_bad_weakref(self):
import weakref

a = torch.Event()
with self.assertRaises(TypeError):
weakref.ref(a)

@torch.compile(backend="eager")
class Mod(torch.nn.Module):
def __init__(self, event):
super().__init__()
self.event = event

def forward(self, x):
return x * int(self.event.query())

e = torch.Event()
m = Mod(e)
a = torch.randn(10)
self.assertEqual(m(a), a)


# The private variants of the below functions are extensively tested
# So as long as the signatures match we're good

@@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"]
import weakref

import torch
import torch._dynamo.test_case
@@ -16,14 +15,6 @@ class TestStreams(torch._dynamo.test_case.TestCase):
def tearDownClass(cls):
super().tearDownClass()

def test_stream_weakref(self):
s = torch.Stream()
weakref.ref(s)

def test_event_weakref(self):
e = torch.Event()
weakref.ref(e)

@requires_cuda
def test_run_opcheck(self):
from torch._dynamo.variables.streams import fork_stream, join_stream

@@ -1,11 +1,14 @@
# Owner(s): ["module: inductor"]

from unittest import skipIf
from unittest.mock import Mock

import torch
import torch._inductor.metrics as metrics
import torch.utils.flop_counter
from torch._dynamo.utils import counters
from torch._inductor.dependencies import Dep, ReadWrites
from torch._inductor.scheduler import BaseSchedulerNode, Scheduler
from torch._inductor.utils import fresh_inductor_cache
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_device_type import (
@@ -15,6 +18,7 @@ from torch.testing._internal.common_device_type import (
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
from torch.testing._internal.inductor_utils import IS_BIG_GPU
from torch.utils._ordered_set import OrderedSet


def FlopCounterMode(*args, **kwargs):
@@ -132,6 +136,79 @@ class TestScheduler(TestCase):
counters["inductor"]["flop_count"] = 0
torch._logging.set_logs()

def test_fusion_prevent_too_many_reads_and_writes_prevents_fusion(self):
"""Test that fusion is prevented when unique I/O buffers exceed threshold"""
# Setup: Create nodes with many unique I/O buffers
# node1: reads [A, B, C], writes [D]
# node2: reads [D, E, F], writes [G]
# D becomes internal (node2 reads node1's write)
# After fusion: unique I/O = {A, B, C, E, F, G} = 6 buffers
scheduler = Mock(spec=Scheduler)
scheduler.can_buffer_be_removed_through_fusion = Mock(return_value=False)

node1 = self._create_mock_node(
name="node1", reads=["A", "B", "C"], writes=["D"]
)
node2 = self._create_mock_node(
name="node2", reads=["D", "E", "F"], writes=["G"]
)

# Execute: Check with threshold of 5 (should prevent fusion since 6 > 5)
result = Scheduler.fusion_prevent_too_many_reads_and_writes(
scheduler, node1, node2, threshold=5
)

# Assert: Fusion should be prevented (6 unique buffers > 5 threshold)
self.assertTrue(result)

def test_fusion_prevent_too_many_reads_and_writes_allows_fusion(self):
"""Test that fusion is allowed when intermediate buffers are removed"""
# Setup: Create nodes where node2 reads node1's output
# node1: reads [A, B], writes [C]
# node2: reads [C, D], writes [E]
# C becomes internal (node2 reads node1's write)
# After fusion: unique I/O = {A, B, D, E} = 4 buffers
scheduler = Mock(spec=Scheduler)
scheduler.can_buffer_be_removed_through_fusion = Mock(return_value=False)

node1 = self._create_mock_node(name="node1", reads=["A", "B"], writes=["C"])
node2 = self._create_mock_node(name="node2", reads=["C", "D"], writes=["E"])

# Execute: Check with threshold of 5 (should allow fusion since 4 <= 5)
result = Scheduler.fusion_prevent_too_many_reads_and_writes(
scheduler, node1, node2, threshold=5
)

# Assert: Fusion should be allowed (4 unique buffers <= 5 threshold)
self.assertFalse(result)

def _create_mock_node(self, name: str, reads: list[str], writes: list[str]) -> Mock:
"""Helper method to create a mock scheduler node with specified reads/writes"""
node = Mock(spec=BaseSchedulerNode)
node.get_name = Mock(return_value=name)
node.get_nodes = Mock(return_value=[node])

# Create mock Dep objects for reads and writes
read_deps = OrderedSet()
for read_name in reads:
dep = Mock(spec=Dep)
dep.name = read_name
read_deps.add(dep)

write_deps = OrderedSet()
for write_name in writes:
dep = Mock(spec=Dep)
dep.name = write_name
write_deps.add(dep)

# Create mock ReadWrites object
read_writes = Mock(spec=ReadWrites)
read_writes.reads = read_deps
read_writes.writes = write_deps

node.read_writes = read_writes
return node


instantiate_device_type_tests(TestScheduler, globals())


test/inductor/test_selective_lowering.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# Owner(s): ["module: inductor"]
"""
Test selective lowering control via node metadata annotations.
"""

from collections.abc import Callable

import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


@instantiate_parametrized_tests
class SelectiveLoweringTest(InductorTestCase):
"""
Tests for user-controllable selective lowering using node.meta annotations.
"""

device = GPU_TYPE

def _mark_nodes_for_fallback(
self, gm: torch.fx.GraphModule, predicate: Callable[[torch.fx.Node], bool]
) -> torch.fx.GraphModule:
"""
Helper method to mark nodes with should_fallback metadata based on a predicate.
"""
for node in gm.graph.nodes:
if node.op == "call_function" and predicate(node):
node.meta["should_fallback"] = True
return gm

def test_basic_selective_lowering(self):
"""
Test that nodes marked for fallback use fallback handlers instead of lowerings.
"""

def foo(x, y):
a = x + y # This will be marked for fallback
b = a * 2 # This will use normal lowering
return b

x = torch.randn(10, device=self.device)
y = torch.randn(10, device=self.device)

def custom_backend(gm: torch.fx.GraphModule, example_inputs):
# Mark all add operations for fallback
def should_fallback_add(node: torch.fx.Node) -> bool:
return node.target == torch.ops.aten.add.Tensor

self._mark_nodes_for_fallback(gm, should_fallback_add)

from torch._inductor.compile_fx import compile_fx

return compile_fx(gm, example_inputs)

compiled_fn = torch.compile(foo, backend=custom_backend)
result = compiled_fn(x, y)
expected = foo(x, y)

self.assertTrue(torch.allclose(result, expected))

def test_no_fallback_when_unmarked(self):
"""
Test that operations without fallback annotation use normal lowering.
"""

def foo(x, y):
return x + y

x = torch.randn(10, device=self.device)
y = torch.randn(10, device=self.device)

def custom_backend(gm: torch.fx.GraphModule, example_inputs):
# Don't mark anything - all operations should use normal lowering
from torch._inductor.compile_fx import compile_fx

return compile_fx(gm, example_inputs)

compiled_fn = torch.compile(foo, backend=custom_backend)
result = compiled_fn(x, y)
expected = foo(x, y)

self.assertTrue(torch.allclose(result, expected))


if __name__ == "__main__":
from torch._inductor.test_case import run_tests

if HAS_GPU:
run_tests(needs="filelock")

@@ -459,6 +459,8 @@ class TestMatmulCuda(InductorTestCase):
@parametrize("b_row_major", [False, True])
@dtypes(torch.bfloat16, torch.float32, torch.float16)
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
self.skipTest("failed using hipblaslt on rocm 6.4.2")
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4

@@ -2598,7 +2598,7 @@ class TestTorchDeviceType(TestCase):
dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
y = x.clone()
x.requires_grad = True
d = torch.cdist(x, y)
d = torch.cdist(x, y, p=p)
d.backward(dist_grad)
# Check that the backward pass does not contain invalid
# values such as nan or inf

@@ -5,12 +5,11 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.varlen import varlen_attn
from torch.nn.attention import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.utils._python_dispatch import TorchDispatchMode


VarlenShape = namedtuple(
@@ -24,18 +23,6 @@ default_tolerances = {
}


class OpLoggingMode(TorchDispatchMode):
"""Logging mode that captures all dispatched operations"""

def __init__(self):
self.called_ops = []

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
op_name = str(func)
self.called_ops.append(op_name)
return func(*args, **(kwargs or {}))


class AttentionBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
@@ -52,9 +39,12 @@ class AttentionBlock(nn.Module):
embed_dim, embed_dim, bias=False, device=device, dtype=dtype
)

def get_varlen_qkv(
def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
qkv = self.qkv_proj(x_packed)
q, k, v = qkv.chunk(3, dim=-1)
@@ -63,51 +53,24 @@ class AttentionBlock(nn.Module):
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)

return q, k, v

def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
q, k, v = self.get_varlen_qkv(x_packed)

attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
attn_out = varlen_attn(
q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
)
attn_out = attn_out.view(-1, self.embed_dim)

return self.out_proj(attn_out)

def forward_sdpa(
self,
x_padded: torch.Tensor,
seq_lengths: torch.Tensor,
dtype: torch.dtype,
is_causal: bool = False,
):
def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
batch_size, seq_len, _ = x_padded.shape

qkv = self.qkv_proj(x_padded)
q, k, v = qkv.chunk(3, dim=-1)

mask = (
torch.arange(seq_len, device=x_padded.device)[None, :]
< seq_lengths[:, None]
)

attn_mask = mask[:, None, None, :].expand(
batch_size, self.num_heads, seq_len, seq_len
)

q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=is_causal
)

attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
attn_out = (
attn_out.transpose(1, 2)
.contiguous()
@@ -128,9 +91,7 @@ def create_variable_length_batch(
seq_lengths = torch.tensor(seq_lengths, device=device)
total_tokens = seq_lengths.sum().item()

x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
)
x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)

cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
@@ -145,7 +106,6 @@ def create_variable_length_batch(
end_idx = start_idx + seq_len
x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
start_idx = end_idx
x_padded = x_padded.clone().detach().requires_grad_()

return {
"seq_lengths": seq_lengths,
@@ -173,11 +133,7 @@ class TestVarlenAttention(NNTestCase):

total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
total_tokens, shape.embed_dim, device=device, dtype=dtype
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
@@ -191,128 +147,6 @@ class TestVarlenAttention(NNTestCase):
self.assertEqual(output.device, torch.device(device))
self.assertEqual(output.dtype, dtype)

varlen_grad_out = torch.ones_like(output)

varlen_grad = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]

self.assertIsNotNone(varlen_grad)
self.assertEqual(varlen_grad.shape, x_packed.shape)
self.assertEqual(varlen_grad.dtype, x_packed.dtype)

@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_compliance(self, device, dtype):
torch.manual_seed(42)

shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)

attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)

total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)

q, k, v = attention_block.get_varlen_qkv(x_packed)

torch.library.opcheck(
torch.ops.torch_attn._varlen_attn,
(q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
)

out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
)
grad_out = torch.randn_like(out)

# we don't support double backward
# skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn_backward,
(
grad_out,
q,
k,
v,
out,
lse,
cu_seq,
cu_seq,
shape.max_seq_len,
shape.max_seq_len,
False,
rng_state,
),
test_utils=["test_schema", "test_faketensor"],
)

@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_registration(self, device, dtype):
torch.manual_seed(42)

shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)

attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)

total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)

compiled_forward = torch.compile(
attention_block.forward_varlen, backend="eager", fullgraph=True
)
with OpLoggingMode() as mode:
output = compiled_forward(
x_packed, cu_seq, shape.max_seq_len, is_causal=False
)

varlen_grad_out = torch.ones_like(output)
_ = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]

called_ops = mode.called_ops

custom_ops_called = any(
"torch_attn._varlen_attn" in op for op in called_ops
) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
assert custom_ops_called

@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@@ -338,10 +172,7 @@ class TestVarlenAttention(NNTestCase):
is_causal=is_causal,
)
sdpa_output = attention_block.forward_sdpa(
variable_length_batch_data["x_padded"],
variable_length_batch_data["seq_lengths"],
dtype=dtype,
is_causal=is_causal,
variable_length_batch_data["x_padded"], is_causal=is_causal
)

tolerances = default_tolerances[dtype]
@@ -355,44 +186,6 @@ class TestVarlenAttention(NNTestCase):
torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
start_idx = end_idx

varlen_grad_out = torch.ones_like(varlen_output)

sdpa_grad_out = torch.zeros_like(sdpa_output)

start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
start_idx = end_idx

varlen_grad = torch.autograd.grad(
outputs=varlen_output,
inputs=variable_length_batch_data["x_packed"],
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]

sdpa_grad = torch.autograd.grad(
outputs=sdpa_output,
inputs=variable_length_batch_data["x_padded"],
grad_outputs=sdpa_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]

start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len

varlen_grad_seq = varlen_grad[start_idx:end_idx]
sdpa_grad_seq = sdpa_grad[i, :seq_len]

torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
start_idx = end_idx


device_types = ("cuda",)


@@ -374,6 +374,22 @@ def build_collectives(
return tracebacks, collectives, nccl_calls


def transform_ft(
details: dict[str, dict[str, Any]], group_world_size: int
) -> dict[str, dict[str, Any]]:
for dump_key, dump in details.items():
rank = dump["rank"]
for key, pg_config in dump["pg_config"].items():
if pg_config["desc"] == "default_pg":
ranks = eval(pg_config["ranks"])
replica_id = rank // group_world_size
first_rank = replica_id * group_world_size
new_ranks = [r + first_rank for r in ranks]
details[dump_key]["pg_config"][key]["ranks"] = f"{new_ranks}"

return details


def build_db(
details: dict[str, dict[str, Any]], args: argparse.Namespace, version: str
) -> Database:

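As a worked example of the remapping in `transform_ft` above (values chosen purely for illustration): with a replica group of 2 ranks, a dump recorded by rank 3 belongs to replica 1, so a default_pg entry listing local ranks [0, 1] is rewritten to the global ranks [2, 3].

```python
group_world_size = 2
rank = 3            # rank recorded in this dump
ranks = [0, 1]      # default_pg ranks, local to the torchft replica group

replica_id = rank // group_world_size        # 1
first_rank = replica_id * group_world_size   # 2
new_ranks = [r + first_rank for r in ranks]
print(new_ranks)  # [2, 3]
```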
@@ -74,6 +74,17 @@ class JobConfig:
default=10,
help="Maximum number of mismatches we print (from earliest).",
)
self.parser.add_argument(
"--transform-ft",
action="store_true",
help="Transform PG config to use global ranks to analyze traces produced by torchft",
)
self.parser.add_argument(
"--group-world-size",
type=int,
default=None,
help="The number of ranks in 1 torchft replica group. Must be specified if --transform-ft is True",
)

def parse_args(
self: "JobConfig", args: Optional[Sequence[str]]

@@ -32,7 +32,7 @@ import pickle
from collections.abc import Sequence
from typing import Optional

from tools.flight_recorder.components.builder import build_db
from tools.flight_recorder.components.builder import build_db, transform_ft
from tools.flight_recorder.components.config_manager import JobConfig
from tools.flight_recorder.components.loader import read_dir
from tools.flight_recorder.components.types import types
@@ -46,6 +46,9 @@ def main(args: Optional[Sequence[str]] = None) -> None:
assert args.trace_dir, "Trace directory trace_dir is required"
# pyrefly: ignore [bad-argument-type]
details, version = read_dir(args)
if args.transform_ft:
assert args.group_world_size, "World size is required for transform_ft"
details = transform_ft(details, args.group_world_size)
# pyrefly: ignore [bad-argument-type]
db = build_db(details, args, version)
# pyrefly: ignore [missing-attribute]

@@ -2150,19 +2150,6 @@ class GuardBuilder(GuardBuilderBase):
metadata_checker, get_verbose_code_parts(global_name, guard)
)

def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None:
# Copied from DTensor __metadata_guard__
# TODO - Consider moving this to C++ if stable
value = deepcopy(self.get(guard.name))

def guard_fn(x: Any) -> bool:
return x._check_equals(value, skip_shapes=True)

code = f"__dtensor_spec_{id(guard_fn)}"
self.get_guard_manager(guard).add_lambda_guard(
guard_fn, get_verbose_code_parts(code, guard)
)

def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
ref = self.arg_ref(guard)
val = self.get(guard.name)

@@ -346,10 +346,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)

def _is_pytreespec_instance(
obj: Any, /
) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]:
return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec))
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)

@substitute_in_graph( # type: ignore[arg-type]
optree.treespec_leaf,
@@ -552,7 +550,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return treespec.unflatten(leaves)

@@ -2229,70 +2229,25 @@ class VariableBuilder:
if isinstance(source, GradSource) and is_from_optimizer_source(source):
guard_type = GuardBuilder.NOT_NONE_MATCH

is_dtensor = torch.distributed.is_available() and isinstance(
value, torch.distributed.tensor.DTensor
)
if not is_dtensor:
# We guard on the _local_tensor and the _spec, and therefore we dont
# have to guard on the outer DTensor.
self.install_guards(
functools.partial(
guard_type,
value=(
value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value)
),
)
self.install_guards(
functools.partial(
guard_type,
value=(
value
if isinstance(source, NumpyTensorSource)
else TensorWeakRef(value)
),
)
)

# We install TYPE_MATCH guards for traceable wrapper subclass object,
# and recursively install corresponding guard for each inner attribute.
if is_traceable_wrapper_subclass(value):
# Tensor subclass guards are very expensive because they are
# implemented in Python. Since DTensor is PyTorch-maintained class,
# we can skip a lot of these guards.
if is_dtensor:
self.install_guards(GuardBuilder.TYPE_MATCH)

# The inner tensor name is always _local_tensor. If its not, we
# raise assertion to update the check accordingly.
inner_tensor_name = value.__tensor_flatten__()[0][0]
if inner_tensor_name != "_local_tensor":
raise RuntimeError(
"Expecting Dtensor inner tensor name to be _local_tensor"
)

# Now selectively guard on the flattening context
flattening_ctx = value.__tensor_flatten__()[1]
# This is supposed to be (self._spec, self.requires_grad)
if not (
len(flattening_ctx) == 2
and flattening_ctx[0] == value._spec
and flattening_ctx[1] == value.requires_grad
):
# If not, raise an assertion to update to the new guards
raise RuntimeError(
"Expecting Dtensor flattening ctx to be _spec, requires_grad"
)
# Guard on the dtensor spec
install_guard(
AttrSource(self.source, "_spec").make_guard(
GuardBuilder.DTENSOR_SPEC_MATCH
)
)
# Move this to C++
install_guard(
AttrSource(self.source, "requires_grad").make_guard(
GuardBuilder.EQUALS_MATCH
)
)
else:
self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
self.install_guards(GuardBuilder.TYPE_MATCH)
install_guard(
SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
)
self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
self.install_guards(GuardBuilder.TYPE_MATCH)
install_guard(
SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH)
)

attrs, _ = value.__tensor_flatten__()
for attr in attrs:

@@ -530,6 +530,17 @@ class InductorChoices:
WhyNoFuse(node1, node2)("Fusion will increase peak memory")
return False

if (
config.max_fusion_unique_io_buffers is not None
and scheduler.fusion_prevent_too_many_reads_and_writes(
node1,
node2,
config.max_fusion_unique_io_buffers,
)
):
WhyNoFuse(node1, node2)("fusion_prevent_too_many_reads_and_writes")
return False

return True

@staticmethod

@ -688,6 +688,10 @@ max_fusion_size = 64
# how many nodes to attempt pairwise fusion with in a buffer group
max_fusion_buffer_group_pairwise_attempts = 64

# maximum number of unique input/output buffers allowed in fused kernels.
# The check is disabled if set to None.
max_fusion_unique_io_buffers: Optional[int] = None

# max number of inputs to generate cat as a pointwise op with masked loads
max_pointwise_cat_inputs = 8

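A minimal usage sketch of the new knob (an illustration only, assuming the option is read like the neighbouring Inductor config options at compile time):

import torch
import torch._inductor.config as inductor_config

# Cap fused kernels at 32 unique input/output buffers; None (the default)
# disables the check entirely.
inductor_config.max_fusion_unique_io_buffers = 32

compiled = torch.compile(lambda x: (x.sin() + x.cos()).relu())
out = compiled(torch.randn(8, 8))
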
@ -1322,7 +1322,12 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
else:
|
||||
args, kwargs = layout_constraints(n, *args, **kwargs)
|
||||
|
||||
out = lowerings[target](*args, **kwargs) # type: ignore[index]
|
||||
if "should_fallback" in n.meta:
|
||||
out = fallback_handler(target, add_to_fallback_set=False)(
|
||||
*args, **kwargs
|
||||
)
|
||||
else:
|
||||
out = lowerings[target](*args, **kwargs) # type: ignore[index]
|
||||
|
||||
if layout_constraints:
|
||||
# layout_constraints are allowed to make new copies of the inputs.
|
||||
|
||||
@ -4113,6 +4113,54 @@ class Scheduler:
|
||||
return True
|
||||
return False
|
||||
|
||||
def fusion_prevent_too_many_reads_and_writes(
|
||||
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int
|
||||
) -> bool:
|
||||
# After fusion, we need to calculate the unique I/O buffers
|
||||
# accounting for buffers that become internal (removed through fusion)
|
||||
|
||||
# Get all nodes that will be in the fused node
|
||||
fused_node_names = OrderedSet(
|
||||
[node.get_name() for node in node1.get_nodes()]
|
||||
+ [node.get_name() for node in node2.get_nodes()]
|
||||
)
|
||||
|
||||
# Calculate node2 reads that can be removed through fusion,
|
||||
# i.e. node2 reads that are outputs of node1
|
||||
node1_write_names = OrderedSet(dep.name for dep in node1.read_writes.writes)
|
||||
node2_read_names = OrderedSet(dep.name for dep in node2.read_writes.reads)
|
||||
reads_removed_through_fusion = node2_read_names & node1_write_names
|
||||
|
||||
# Calculate node1 writes that can be removed through fusion,
|
||||
# i.e. node1 writes that are only read by node2
|
||||
writes_removed_through_fusion: OrderedSet[str] = OrderedSet()
|
||||
for write_dep in node1.read_writes.writes:
|
||||
if self.can_buffer_be_removed_through_fusion(
|
||||
write_dep.name, fused_node_names
|
||||
):
|
||||
writes_removed_through_fusion.add(write_dep.name)
|
||||
|
||||
# Get all unique reads (union of both nodes' reads)
|
||||
all_read_names = OrderedSet(
|
||||
dep.name for dep in node1.read_writes.reads
|
||||
) | OrderedSet(dep.name for dep in node2.read_writes.reads)
|
||||
|
||||
# Get all unique writes (union of both nodes' writes)
|
||||
all_write_names = OrderedSet(
|
||||
dep.name for dep in node1.read_writes.writes
|
||||
) | OrderedSet(dep.name for dep in node2.read_writes.writes)
|
||||
|
||||
# Remove reads that become internal
|
||||
unique_reads = all_read_names - reads_removed_through_fusion
|
||||
|
||||
# Remove writes that become internal
|
||||
unique_writes = all_write_names - writes_removed_through_fusion
|
||||
|
||||
# Get all unique buffer names (reads and writes combined, but no double counting)
|
||||
unique_io_buffers = unique_reads | unique_writes
|
||||
|
||||
return len(unique_io_buffers) > threshold
|
||||
|
||||
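The counting above can be illustrated with plain sets. This is a simplified sketch, not the Scheduler code: it folds the can_buffer_be_removed_through_fusion refinement into the assumption that no writes become internal.

def too_many_io_buffers(reads1, writes1, reads2, writes2, threshold):
    # node2 reads that node1 produces become internal after fusion
    internal_reads = reads2 & writes1
    unique_reads = (reads1 | reads2) - internal_reads
    unique_writes = writes1 | writes2  # simplification: no writes removed
    return len(unique_reads | unique_writes) > threshold

# node1: buf0 -> buf1, node2: (buf0, buf1) -> buf2; unique I/O is {buf0, buf1, buf2}
assert too_many_io_buffers({"buf0"}, {"buf1"}, {"buf0", "buf1"}, {"buf2"}, threshold=2)
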
def are_long_distant_nodes(
|
||||
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
|
||||
) -> bool:
|
||||
|
||||
@ -49,7 +49,6 @@ static PyObject* THPEvent_pynew(
|
||||
}
|
||||
|
||||
THPEvent* self = reinterpret_cast<THPEvent*>(ptr.get());
|
||||
self->weakreflist = nullptr;
|
||||
|
||||
// TODO: blocking and interprocess are not supported yet. To support them, the
|
||||
// flag system of c10::Event needs to be refactored. C10::Event should also
|
||||
@ -74,7 +73,6 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
|
||||
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
|
||||
TORCH_CHECK(self, "Failed to allocate memory for Event");
|
||||
auto self_ = reinterpret_cast<THPEvent*>(self.get());
|
||||
self_->weakreflist = nullptr;
|
||||
new (&self_->event) c10::Event(device_type, flag);
|
||||
return self.release();
|
||||
}
|
||||
@ -84,7 +82,6 @@ static void THPEvent_dealloc(THPEvent* self) {
|
||||
pybind11::gil_scoped_release no_gil{};
|
||||
self->event.~Event();
|
||||
}
|
||||
PyObject_ClearWeakRefs((PyObject*)self);
|
||||
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
|
||||
}
|
||||
|
||||
@ -285,8 +282,7 @@ static PyMethodDef THPEvent_methods[] = {
|
||||
{"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
|
||||
{"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
|
||||
{nullptr}};
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
|
||||
|
||||
PyTypeObject THPEventType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
"torch.Event", /* tp_name */
|
||||
@ -312,7 +308,7 @@ PyTypeObject THPEventType = {
|
||||
nullptr, /* tp_traverse */
|
||||
nullptr, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */
|
||||
0, /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
THPEvent_methods, /* tp_methods */
|
||||
@ -327,7 +323,6 @@ PyTypeObject THPEventType = {
|
||||
nullptr, /* tp_alloc */
|
||||
THPEvent_pynew, /* tp_new */
|
||||
};
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
void THPEvent_init(PyObject* module) {
|
||||
THPEventClass = &THPEventType;
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
struct TORCH_API THPEvent {
|
||||
PyObject_HEAD
|
||||
c10::Event event;
|
||||
PyObject* weakreflist;
|
||||
};
|
||||
TORCH_API extern PyTypeObject* THPEventClass;
|
||||
TORCH_API extern PyTypeObject THPEventType;
|
||||
|
||||
@ -95,7 +95,6 @@ static PyObject* THPStream_pynew(
|
||||
self->device_index = static_cast<int64_t>(stream_opt->device_index());
|
||||
self->device_type = static_cast<int64_t>(stream_opt->device_type());
|
||||
self->context = nullptr;
|
||||
self->weakreflist = nullptr;
|
||||
|
||||
return static_cast<PyObject*>(ptr.release());
|
||||
END_HANDLE_TH_ERRORS
|
||||
@ -115,13 +114,11 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
|
||||
self->device_index = static_cast<int64_t>(stream.device_index());
|
||||
self->device_type = static_cast<int64_t>(stream.device_type());
|
||||
self->context = nullptr;
|
||||
self->weakreflist = nullptr;
|
||||
return ptr.release();
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static void THPStream_dealloc(THPStream* self) {
|
||||
PyObject_ClearWeakRefs((PyObject*)self);
|
||||
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
|
||||
}
|
||||
|
||||
@ -447,7 +444,7 @@ static PyTypeObject THPStreamType = {
|
||||
nullptr, /* tp_traverse */
|
||||
nullptr, /* tp_clear */
|
||||
THPStream_richcompare, /* tp_richcompare */
|
||||
offsetof(THPStream, weakreflist), /* tp_weaklistoffset */
|
||||
0, /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
// NOLINTNEXTLINE(*const-cast)
|
||||
|
||||
@ -13,7 +13,6 @@ struct THPStream {
|
||||
int64_t device_index;
|
||||
// Used to switch stream context management, initialized lazily.
|
||||
PyObject* context;
|
||||
PyObject* weakreflist;
|
||||
};
|
||||
extern TORCH_API PyTypeObject* THPStreamClass;
|
||||
|
||||
|
||||
@ -1406,169 +1406,6 @@ AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
|
||||
});
|
||||
}
|
||||
|
||||
static StableIValue from_ivalue(
|
||||
const c10::TypePtr& type,
|
||||
const c10::IValue& ivalue) {
|
||||
switch (type->kind()) {
|
||||
case c10::TypeKind::TensorType: {
|
||||
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
|
||||
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
|
||||
return torch::stable::detail::from(ath);
|
||||
}
|
||||
case c10::TypeKind::IntType: {
|
||||
return torch::stable::detail::from(ivalue.toInt());
|
||||
}
|
||||
case c10::TypeKind::FloatType: {
|
||||
return torch::stable::detail::from(ivalue.toDouble());
|
||||
}
|
||||
case c10::TypeKind::BoolType: {
|
||||
return torch::stable::detail::from(ivalue.toBool());
|
||||
}
|
||||
case c10::TypeKind::ScalarTypeType: {
|
||||
return torch::stable::detail::from(ivalue.toScalarType());
|
||||
}
|
||||
case c10::TypeKind::DeviceObjType: {
|
||||
return torch::stable::detail::from(ivalue.toDevice());
|
||||
}
|
||||
case c10::TypeKind::LayoutType: {
|
||||
return torch::stable::detail::from(ivalue.toLayout());
|
||||
}
|
||||
case c10::TypeKind::MemoryFormatType: {
|
||||
return torch::stable::detail::from(ivalue.toMemoryFormat());
|
||||
}
|
||||
case c10::TypeKind::OptionalType: {
|
||||
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
||||
|
||||
// ideally, if we had the C++ type corresponding to inner_type, which we
|
||||
// will denote as inner_type::t (does not actually exist), we would be
|
||||
// able to follow the patterned semantic of every other case here in one
|
||||
// line:
|
||||
//
|
||||
// return
|
||||
// torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
|
||||
//
|
||||
// BUT we do NOT have that type inner_type::t readily available, so we
|
||||
// will manually unwrap and recursively call. This implementation MUST
|
||||
// be kept in sync with torch::stable::detail::from<std::optional<T>>
|
||||
// function in torch/csrc/stable/stableivalue_conversions.h
|
||||
if (ivalue.isNone()) {
|
||||
return torch::stable::detail::from(std::nullopt);
|
||||
}
|
||||
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
|
||||
return torch::stable::detail::from(sivp);
|
||||
}
|
||||
default: {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported conversion from IValue to StableIValue for schema type: ",
|
||||
type->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static c10::IValue to_ivalue(
|
||||
const c10::TypePtr& type,
|
||||
const StableIValue stable_ivalue) {
|
||||
switch (type->kind()) {
|
||||
case c10::TypeKind::TensorType: {
|
||||
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
|
||||
torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
|
||||
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
|
||||
ret_raiiath.get())));
|
||||
}
|
||||
case c10::TypeKind::IntType: {
|
||||
return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
|
||||
}
|
||||
case c10::TypeKind::FloatType: {
|
||||
return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
|
||||
}
|
||||
case c10::TypeKind::BoolType: {
|
||||
return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
|
||||
}
|
||||
case c10::TypeKind::ScalarTypeType: {
|
||||
return c10::IValue(
|
||||
torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
|
||||
}
|
||||
case c10::TypeKind::DeviceObjType: {
|
||||
return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
|
||||
}
|
||||
case c10::TypeKind::LayoutType: {
|
||||
return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
|
||||
}
|
||||
case c10::TypeKind::MemoryFormatType: {
|
||||
return c10::IValue(
|
||||
torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
|
||||
}
|
||||
case c10::TypeKind::OptionalType: {
|
||||
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
||||
|
||||
// ideally, if we had the C++ type corresponding to inner_type, which we
|
||||
// will denote as inner_type::t (does not actually exist), we would be
|
||||
// able to follow the patterned semantic of every other case here in one
|
||||
// line:
|
||||
//
|
||||
// return
|
||||
// c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
|
||||
//
|
||||
// BUT we do NOT have that type inner_type::t readily available, so we
|
||||
// will manually unwrap and recursively call. This implementation MUST
|
||||
// be kept in sync with the torch::stable::detail::to<T> function in
|
||||
// torch/csrc/stable/stableivalue_conversions.h
|
||||
if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
|
||||
return c10::IValue();
|
||||
}
|
||||
auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
|
||||
auto ival = to_ivalue(inner_type, *sivp);
|
||||
delete sivp;
|
||||
return ival;
|
||||
}
|
||||
default: {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported conversion from StableIValue to IValue for schema type: ",
|
||||
type->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class StableIValueBoxedKernel : public c10::OperatorKernel {
|
||||
public:
|
||||
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
|
||||
: fn_(fn) {}
|
||||
|
||||
void operator()(
|
||||
const c10::OperatorHandle& op,
|
||||
c10::DispatchKeySet keyset,
|
||||
torch::jit::Stack* stack) {
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
|
||||
auto ministack =
|
||||
std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
|
||||
|
||||
for (const auto idx : c10::irange(num_arguments)) {
|
||||
const auto ministack_idx = num_arguments - idx - 1;
|
||||
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
|
||||
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
|
||||
}
|
||||
|
||||
// boxed function is going to take a stack of StableIValues, cast them to
|
||||
// our schema values, and run the function and modify the StableIValue stack
|
||||
fn_(ministack.get(), num_arguments, num_returns);
|
||||
|
||||
// read the output from the end of the stack and wrap that back into
|
||||
// IValue from StableIValue
|
||||
for (size_t idx = 0; idx < num_returns; idx++) {
|
||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void (*fn_)(StableIValue*, uint64_t, uint64_t);
|
||||
};
|
||||
|
||||
AOTITorchError aoti_torch_library_init_impl(
|
||||
const char* ns,
|
||||
const char* k,
|
||||
@ -1618,18 +1455,6 @@ AOTITorchError aoti_torch_library_init_fragment(
|
||||
});
|
||||
}
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
|
||||
TorchLibraryHandle self,
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
reinterpret_cast<torch::Library*>(self)->impl(
|
||||
name,
|
||||
torch::CppFunction::makeFromBoxedFunctor(
|
||||
std::make_unique<StableIValueBoxedKernel>(fn)));
|
||||
});
|
||||
}
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
|
||||
@ -1642,40 +1467,6 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
|
||||
{ delete reinterpret_cast<torch::Library*>(tlh); });
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_call_dispatcher(
|
||||
const char* opName,
|
||||
const char* overloadName,
|
||||
StableIValue* stack) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
const auto op =
|
||||
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
|
||||
torch::jit::Stack ivalue_stack;
|
||||
// we will only need max(num_args, num_returns)
|
||||
ivalue_stack.reserve(std::max(num_arguments, num_returns));
|
||||
|
||||
// convert StableIValue stack to c10::IValue stack
|
||||
for (const auto idx : c10::irange(num_arguments)) {
|
||||
auto stable_ivalue = stack[idx];
|
||||
auto arg_type = schema.arguments()[idx].type();
|
||||
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
|
||||
}
|
||||
|
||||
op.callBoxed(ivalue_stack);
|
||||
|
||||
// there should then be num_returns IValues on the stack, which
|
||||
// we will convert to StableIValue and repopulate user input stack
|
||||
for (const auto idx : c10::irange(num_returns)) {
|
||||
const auto stack_idx = num_returns - idx - 1;
|
||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_create_device_guard(
|
||||
int32_t device_index,
|
||||
DeviceGuardHandle* ret_guard // returns new reference
|
||||
|
||||
@ -260,82 +260,20 @@ typedef __half half;
|
||||
)";
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION < 70000
|
||||
#if defined(USE_ROCM)
|
||||
|
||||
#if ROCM_VERSION >= 70000
|
||||
#define BF16_UINT32_DEF "typedef unsigned int uint32_t;\n"
|
||||
#else
|
||||
#define BF16_UINT32_DEF ""
|
||||
#endif
|
||||
|
||||
constexpr auto bfloat16_support_literal =
|
||||
R"(
|
||||
#ifndef __align__
|
||||
#define __align__(x) __attribute__((aligned(x)))
|
||||
#endif
|
||||
|
||||
typedef struct __align__(2) {
|
||||
unsigned short x;
|
||||
}
|
||||
__nv_bfloat16_raw;
|
||||
|
||||
#if defined(__cplusplus)
|
||||
struct __align__(2) __nv_bfloat16 {
|
||||
__host__ __device__ __nv_bfloat16() {}
|
||||
|
||||
__host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
|
||||
__x = hr.x;
|
||||
return *this;
|
||||
}
|
||||
|
||||
unsigned short __x;
|
||||
};
|
||||
|
||||
__device__ unsigned short __internal_float2bfloat16(
|
||||
const float f,
|
||||
unsigned int& sign,
|
||||
unsigned int& remainder) {
|
||||
unsigned int x;
|
||||
|
||||
x = __float_as_uint(f);
|
||||
|
||||
if ((x & 0x7fffffffU) > 0x7f800000U) {
|
||||
sign = 0U;
|
||||
remainder = 0U;
|
||||
return static_cast<unsigned short>(0x7fffU);
|
||||
}
|
||||
sign = x >> 31;
|
||||
remainder = x << 16;
|
||||
return static_cast<unsigned short>(x >> 16);
|
||||
}
|
||||
|
||||
/* Definitions of intrinsics */
|
||||
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
|
||||
__nv_bfloat16 val;
|
||||
__nv_bfloat16_raw r;
|
||||
unsigned int sign;
|
||||
unsigned int remainder;
|
||||
r.x = __internal_float2bfloat16(a, sign, remainder);
|
||||
if ((remainder > 0x80000000U) ||
|
||||
((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
|
||||
r.x++;
|
||||
}
|
||||
val = r;
|
||||
return val;
|
||||
}
|
||||
|
||||
__device__ float __bfloat162float(const __nv_bfloat16 a) {
|
||||
union
|
||||
{
|
||||
uint32_t int32;
|
||||
float fp32;
|
||||
} u = {uint32_t(a.__x) << 16};
|
||||
return u.fp32;
|
||||
}
|
||||
#endif /* defined(__cplusplus) */
|
||||
)";
|
||||
#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
|
||||
constexpr auto bfloat16_support_literal =
|
||||
R"(
|
||||
#ifndef __align__
|
||||
#define __align__(x) __attribute__((aligned(x)))
|
||||
#endif
|
||||
|
||||
typedef unsigned int uint32_t;
|
||||
|
||||
)" BF16_UINT32_DEF R"(
|
||||
typedef struct __align__(2) {
|
||||
unsigned short x;
|
||||
}
|
||||
|
||||
torch/csrc/shim_common.cpp (new file, 417 lines)
@ -0,0 +1,417 @@
|
||||
#include <c10/core/DispatchKey.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/inductor/aoti_runtime/utils.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||
#include <torch/csrc/jit/serialization/pickle.h>
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
|
||||
static StableIValue from_ivalue(
|
||||
const c10::TypePtr& type,
|
||||
const c10::IValue& ivalue,
|
||||
uint64_t extension_build_version) {
|
||||
switch (type->kind()) {
|
||||
case c10::TypeKind::TensorType: {
|
||||
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
|
||||
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
|
||||
return torch::stable::detail::_from(ath, extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::IntType: {
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toInt(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::FloatType: {
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toDouble(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::BoolType: {
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toBool(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::ScalarTypeType: {
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toScalarType(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::DeviceObjType: {
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toDevice(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::LayoutType: {
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toLayout(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::MemoryFormatType: {
|
||||
return torch::stable::detail::_from(
|
||||
ivalue.toMemoryFormat(), extension_build_version);
|
||||
}
|
||||
case c10::TypeKind::OptionalType: {
|
||||
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
||||
|
||||
// ideally, if we had the C++ type corresponding to inner_type, which we
|
||||
// will denote as inner_type::t (does not actually exist), we would be
|
||||
// able to follow the patterned semantic of every other case here in one
|
||||
// line:
|
||||
//
|
||||
// return
|
||||
// torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
|
||||
//
|
||||
// BUT we do NOT have that type inner_type::t readily available, so we
|
||||
// will manually unwrap and recursively call. This implementation MUST
|
||||
// be kept in sync with torch::stable::detail::from<std::optional<T>>
|
||||
// function in torch/csrc/stable/stableivalue_conversions.h
|
||||
if (ivalue.isNone()) {
|
||||
return torch::stable::detail::_from(
|
||||
std::nullopt, extension_build_version);
|
||||
}
|
||||
StableIValue* sivp = new StableIValue(
|
||||
from_ivalue(inner_type, ivalue, extension_build_version));
|
||||
return torch::stable::detail::_from(sivp, extension_build_version);
|
||||
}
|
||||
default: {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported conversion from IValue to StableIValue for schema type: ",
|
||||
type->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static c10::IValue to_ivalue(
|
||||
const c10::TypePtr& type,
|
||||
const StableIValue stable_ivalue,
|
||||
uint64_t extension_build_version) {
|
||||
switch (type->kind()) {
|
||||
case c10::TypeKind::TensorType: {
|
||||
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
|
||||
torch::stable::detail::_to<AtenTensorHandle>(
|
||||
stable_ivalue, extension_build_version));
|
||||
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
|
||||
ret_raiiath.get())));
|
||||
}
|
||||
case c10::TypeKind::IntType: {
|
||||
return c10::IValue(torch::stable::detail::_to<int64_t>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::FloatType: {
|
||||
return c10::IValue(torch::stable::detail::_to<double>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::BoolType: {
|
||||
return c10::IValue(torch::stable::detail::_to<bool>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::ScalarTypeType: {
|
||||
return c10::IValue(torch::stable::detail::_to<c10::ScalarType>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::DeviceObjType: {
|
||||
return c10::IValue(torch::stable::detail::_to<c10::Device>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::LayoutType: {
|
||||
return c10::IValue(torch::stable::detail::_to<c10::Layout>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::MemoryFormatType: {
|
||||
return c10::IValue(torch::stable::detail::_to<c10::MemoryFormat>(
|
||||
stable_ivalue, extension_build_version));
|
||||
}
|
||||
case c10::TypeKind::OptionalType: {
|
||||
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
|
||||
|
||||
// ideally, if we had the C++ type corresponding to inner_type, which we
|
||||
// will denote as inner_type::t (does not actually exist), we would be
|
||||
// able to follow the patterned semantic of every other case here in one
|
||||
// line:
|
||||
//
|
||||
// return
|
||||
// c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
|
||||
//
|
||||
// BUT we do NOT have that type inner_type::t readily available, so we
|
||||
// will manually unwrap and recursively call. This implementation MUST
|
||||
// be kept in sync with the torch::stable::detail::_to<T> function in
|
||||
// torch/csrc/stable/library.h
|
||||
if (stable_ivalue ==
|
||||
torch::stable::detail::_from(std::nullopt, extension_build_version)) {
|
||||
return c10::IValue();
|
||||
}
|
||||
auto sivp = torch::stable::detail::_to<StableIValue*>(
|
||||
stable_ivalue, extension_build_version);
|
||||
auto ival = to_ivalue(inner_type, *sivp, extension_build_version);
|
||||
delete sivp;
|
||||
return ival;
|
||||
}
|
||||
default: {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported conversion from StableIValue to IValue for schema type: ",
|
||||
type->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class StableIValueBoxedKernel : public c10::OperatorKernel {
|
||||
public:
|
||||
StableIValueBoxedKernel(
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t),
|
||||
uint64_t extension_build_version)
|
||||
: fn_(fn), extension_build_version_(extension_build_version) {}
|
||||
|
||||
void operator()(
|
||||
const c10::OperatorHandle& op,
|
||||
c10::DispatchKeySet keyset,
|
||||
torch::jit::Stack* stack) {
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
|
||||
auto ministack =
|
||||
std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
|
||||
|
||||
for (const auto idx : c10::irange(num_arguments)) {
|
||||
const auto ministack_idx = num_arguments - idx - 1;
|
||||
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
|
||||
ministack[ministack_idx] = from_ivalue(
|
||||
arg_type, torch::jit::pop(stack), extension_build_version_);
|
||||
}
|
||||
|
||||
// boxed function is going to take a stack of StableIValues, cast them to
|
||||
// our schema values, and run the function and modify the StableIValue stack
|
||||
fn_(ministack.get(), num_arguments, num_returns);
|
||||
|
||||
// read the output from the end of the stack and wrap that back into
|
||||
// IValue from StableIValue
|
||||
for (size_t idx = 0; idx < num_returns; idx++) {
|
||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||
torch::jit::push(
|
||||
stack, to_ivalue(ret_type, ministack[idx], extension_build_version_));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void (*fn_)(StableIValue*, uint64_t, uint64_t);
|
||||
uint64_t extension_build_version_;
|
||||
};
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
|
||||
TorchLibraryHandle self,
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
reinterpret_cast<torch::Library*>(self)->impl(
|
||||
name,
|
||||
torch::CppFunction::makeFromBoxedFunctor(
|
||||
std::make_unique<StableIValueBoxedKernel>(fn, TORCH_ABI_VERSION)));
|
||||
});
|
||||
}
|
||||
|
||||
// Version-aware variant of aoti_torch_library_impl that takes an
|
||||
// extension_build_version parameter for backward compatibility
|
||||
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
|
||||
TorchLibraryHandle self,
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t),
|
||||
uint64_t extension_build_version) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
reinterpret_cast<torch::Library*>(self)->impl(
|
||||
name,
|
||||
torch::CppFunction::makeFromBoxedFunctor(
|
||||
std::make_unique<StableIValueBoxedKernel>(
|
||||
fn, extension_build_version)));
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_call_dispatcher(
|
||||
const char* opName,
|
||||
const char* overloadName,
|
||||
StableIValue* stack) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
const auto op =
|
||||
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
|
||||
torch::jit::Stack ivalue_stack;
|
||||
// we will only need max(num_args, num_returns)
|
||||
ivalue_stack.reserve(std::max(num_arguments, num_returns));
|
||||
|
||||
// convert StableIValue stack to c10::IValue stack
|
||||
for (const auto idx : c10::irange(num_arguments)) {
|
||||
auto stable_ivalue = stack[idx];
|
||||
auto arg_type = schema.arguments()[idx].type();
|
||||
torch::jit::push(
|
||||
ivalue_stack, to_ivalue(arg_type, stable_ivalue, TORCH_ABI_VERSION));
|
||||
}
|
||||
|
||||
op.callBoxed(ivalue_stack);
|
||||
|
||||
// there should then be num_returns IValues on the stack, which
|
||||
// we will convert to StableIValue and repopulate user input stack
|
||||
for (const auto idx : c10::irange(num_returns)) {
|
||||
const auto stack_idx = num_returns - idx - 1;
|
||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||
stack[stack_idx] = from_ivalue(
|
||||
ret_type, torch::jit::pop(ivalue_stack), TORCH_ABI_VERSION);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Schema Adapter Infrastructure
|
||||
// SchemaAdapterRegistry contains the adapters registered via
|
||||
// register_schema_adapter that define how to convert the StableIValue argument
|
||||
// stack to an IValue stack when changes are made to the schema of an ATen
|
||||
// function. This should only be relevant in the context of calling
|
||||
// torch_call_dispatcher.
|
||||
|
||||
// Currently this only adapts the argument stack.
|
||||
// C++ default argument resolution will happen at compile time in the
|
||||
// torch/csrc/stable/ops.h header, so extensions always pass complete argument
|
||||
// lists for the schema of the version they build against. As such, this is only
|
||||
// needed if a new argument is added to the schema
|
||||
//
|
||||
// This is not declared in the stable shim.h,
|
||||
// so we **do not make any guarantees that the signature of this will not
|
||||
// change**. If there is a need to define similar infrastructure for the returns
|
||||
// of an aten function we can update this.
|
||||
|
||||
namespace {
|
||||
using SchemaAdapterFn = std::function<torch::jit::Stack(
|
||||
const c10::FunctionSchema& current_schema,
|
||||
const StableIValue* extension_stack,
|
||||
uint64_t extension_build_version)>;
|
||||
|
||||
// Global registry for schema adapters
|
||||
class SchemaAdapterRegistry {
|
||||
private:
|
||||
std::unordered_map<
|
||||
std::string,
|
||||
std::vector<std::pair<uint64_t, SchemaAdapterFn>>>
|
||||
adapters_;
|
||||
|
||||
public:
|
||||
static SchemaAdapterRegistry& instance() {
|
||||
static SchemaAdapterRegistry registry;
|
||||
return registry;
|
||||
}
|
||||
|
||||
void register_adapter(
|
||||
const std::string& op_name,
|
||||
uint64_t
|
||||
applies_to_versions_below, // versions below this need the adapter
|
||||
SchemaAdapterFn adapter) {
|
||||
adapters_[op_name].emplace_back(applies_to_versions_below, adapter);
|
||||
// Sort by version ascending - this allows us to find the first (most
|
||||
// specific) match
|
||||
std::sort(
|
||||
adapters_[op_name].begin(),
|
||||
adapters_[op_name].end(),
|
||||
[](const auto& a, const auto& b) { return a.first < b.first; });
|
||||
}
|
||||
|
||||
std::optional<SchemaAdapterFn> get_adapter(
|
||||
const std::string& op_name,
|
||||
uint64_t extension_version) {
|
||||
auto it = adapters_.find(op_name);
|
||||
if (it == adapters_.end())
|
||||
return std::nullopt;
|
||||
|
||||
// Find the first adapter that applies (most specific due to ascending sort)
|
||||
for (const auto& [applies_to_versions_below, adapter] : it->second) {
|
||||
if (extension_version < applies_to_versions_below) {
|
||||
return adapter;
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
};
|
||||
|
||||
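A hypothetical Python rendering of the lookup rule implemented by get_adapter above, to make the ordering explicit: candidates are sorted by their applies_to_versions_below bound, and the first bound strictly greater than the extension's build version wins.

def pick_adapter(adapters, extension_version):
    # adapters: list of (applies_to_versions_below, adapter_fn) pairs
    for bound, adapter_fn in sorted(adapters, key=lambda pair: pair[0]):
        if extension_version < bound:
            return adapter_fn  # most specific applicable adapter
    return None  # no adapter needed; fall through to the default path
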
// Internal API for registering adapters that define how to convert the
|
||||
// StableIValue **argument** stack to an IValue stack when changes are
|
||||
// made to the schema of a function. adapter_fn will be used if
|
||||
// extension_build_version < applies_to_versions_below.
|
||||
[[maybe_unused]] AOTITorchError register_schema_adapter(
|
||||
const char* op_name,
|
||||
uint64_t applies_to_versions_below,
|
||||
SchemaAdapterFn adapter_fn) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
auto& registry = SchemaAdapterRegistry::instance();
|
||||
registry.register_adapter(
|
||||
std::string(op_name), applies_to_versions_below, std::move(adapter_fn));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Function to register test schema adapters for _test_schema_upgrader
|
||||
// This demonstrates the adapter registration pattern (internal use only)
|
||||
static AOTITorchError _register_adapters() {
|
||||
// ** Schema adapters should be registered here**
|
||||
// Refer to https://github.com/pytorch/pytorch/pull/165284/ for an example.
|
||||
//
|
||||
// if (auto err = register_schema_adapter(
|
||||
// "aten::your_op",
|
||||
// VERSION_FOO, // applies to versions < VERSION_FOO
|
||||
// adapt_v1_to_vfoo)) {
|
||||
// return err;
|
||||
// }
|
||||
return AOTI_TORCH_SUCCESS;
|
||||
}
|
||||
|
||||
// Static initialization to automatically register test adapters
|
||||
static struct AdapterInitializer {
|
||||
AdapterInitializer() {
|
||||
// Register the test adapters when the library loads
|
||||
_register_adapters();
|
||||
}
|
||||
} adapter_initializer;
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
|
||||
const char* opName,
|
||||
const char* overloadName,
|
||||
StableIValue* stack,
|
||||
// version of stable headers used to build the extension: necessary for
|
||||
// applying schema adapters
|
||||
uint64_t extension_build_version) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
const auto op =
|
||||
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
|
||||
torch::jit::Stack ivalue_stack;
|
||||
auto& registry = SchemaAdapterRegistry::instance();
|
||||
|
||||
// Check if we need an adapter for this operation
|
||||
if (auto adapter = registry.get_adapter(opName, extension_build_version)) {
|
||||
// Use adapter to create IValue stack
|
||||
ivalue_stack = (*adapter)(schema, stack, extension_build_version);
|
||||
} else {
|
||||
// No adapter needed - implementation matches aoti_torch_call_dispatcher
|
||||
ivalue_stack.reserve(std::max(num_arguments, num_returns));
|
||||
for (const auto idx : c10::irange(num_arguments)) {
|
||||
auto stable_ivalue = stack[idx];
|
||||
auto arg_type = schema.arguments()[idx].type();
|
||||
torch::jit::push(
|
||||
ivalue_stack,
|
||||
to_ivalue(arg_type, stable_ivalue, extension_build_version));
|
||||
}
|
||||
}
|
||||
|
||||
op.callBoxed(ivalue_stack);
|
||||
|
||||
// there should then be num_returns IValues on the stack, which
|
||||
// we will convert to StableIValue and repopulate user input stack
|
||||
for (const auto idx : c10::irange(num_returns)) {
|
||||
const auto stack_idx = num_returns - idx - 1;
|
||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||
stack[stack_idx] = from_ivalue(
|
||||
ret_type, torch::jit::pop(ivalue_stack), extension_build_version);
|
||||
}
|
||||
});
|
||||
}
|
||||
torch/csrc/stable/c/shim.h (new file, 46 lines)
@ -0,0 +1,46 @@
|
||||
#ifndef STABLE_TORCH_SHIM
|
||||
#define STABLE_TORCH_SHIM
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
|
||||
#include <torch/csrc/stable/version.h>
|
||||
|
||||
// This header defines stable C API extensions for backward/forward
|
||||
// compatibility when calling ATen operations through the dispatcher.
|
||||
//
|
||||
// This is separate from the main AOTI shim to provide versioning capabilities
|
||||
// for schema changes in native ATen functions.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
using StableIValue = uint64_t;
|
||||
|
||||
// Has the same semantic as aoti_torch_call_dispatcher, but takes an
|
||||
// additional argument for the extension build version. This is
|
||||
// needed for backward compatibility when calling native functions via
|
||||
// the dispatcher. The caller should pass in the libtorch version the
|
||||
// extension was built with (NOT the target version).
|
||||
AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
|
||||
const char* opName,
|
||||
const char* overloadName,
|
||||
StableIValue* stack,
|
||||
uint64_t extension_build_version);
|
||||
|
||||
// Version-aware variant of aoti_torch_library_impl that takes an
|
||||
// extension_build_version parameter for backward compatibility
|
||||
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
|
||||
TorchLibraryHandle self,
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t),
|
||||
uint64_t extension_build_version);
|
||||
|
||||
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
#endif // STABLE_TORCH_SHIM
|
||||
@ -4,12 +4,14 @@
|
||||
// code for better UX.
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
|
||||
// Technically, this file doesn't use anything from stableivalue_conversions.h,
|
||||
// but we need to include it here as the contents of stableivalue_conversions.h
|
||||
// used to live here and so we need to expose them for backwards compatibility.
|
||||
#include <torch/csrc/stable/stableivalue_conversions.h>
|
||||
#include <torch/csrc/stable/version.h>
|
||||
|
||||
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
|
||||
|
||||
@ -81,7 +83,11 @@ class StableLibrary final {
|
||||
StableLibrary& impl(
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
torch_library_impl(lib_, name, fn, TORCH_ABI_VERSION);
|
||||
#else
|
||||
aoti_torch_library_impl(lib_, name, fn);
|
||||
#endif
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@ -8,6 +8,8 @@
|
||||
#include <vector>
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
#include <torch/csrc/stable/version.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
|
||||
@ -25,8 +27,13 @@ inline torch::stable::Tensor empty_like(const torch::stable::Tensor& self) {
|
||||
torch::stable::detail::from(std::nullopt),
|
||||
torch::stable::detail::from(std::nullopt),
|
||||
torch::stable::detail::from(std::nullopt)};
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::empty_like", "", stack.data(), TORCH_ABI_VERSION));
|
||||
#else
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::empty_like", "", stack.data()));
|
||||
#endif
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
@ -201,8 +208,13 @@ inline torch::stable::Tensor transpose(
|
||||
torch::stable::detail::from(self),
|
||||
torch::stable::detail::from(dim0),
|
||||
torch::stable::detail::from(dim1)};
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::transpose", "int", stack.data(), TORCH_ABI_VERSION));
|
||||
#else
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::transpose", "int", stack.data()));
|
||||
#endif
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
@ -212,8 +224,13 @@ inline torch::stable::Tensor transpose(
|
||||
inline torch::stable::Tensor zero_(torch::stable::Tensor& self) {
|
||||
const auto num_args = 1;
|
||||
std::array<StableIValue, num_args> stack{torch::stable::detail::from(self)};
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::zero_", "", stack.data(), TORCH_ABI_VERSION));
|
||||
#else
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::zero_", "", stack.data()));
|
||||
#endif
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
@ -228,8 +245,13 @@ inline torch::stable::Tensor copy_(
|
||||
torch::stable::detail::from(self),
|
||||
torch::stable::detail::from(src),
|
||||
torch::stable::detail::from(non_blocking.value_or(false))};
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::copy_", "", stack.data(), TORCH_ABI_VERSION));
|
||||
#else
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::copy_", "", stack.data()));
|
||||
#endif
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
@ -240,9 +262,20 @@ inline torch::stable::Tensor clone(const torch::stable::Tensor& self) {
|
||||
std::array<StableIValue, num_args> stack{
|
||||
torch::stable::detail::from(self),
|
||||
torch::stable::detail::from(std::nullopt)};
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::clone", "", stack.data(), TORCH_ABI_VERSION));
|
||||
#else
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::clone", "", stack.data()));
|
||||
#endif
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
|
||||
// New ops should be added here if they use a brand new shim API
|
||||
|
||||
#endif
|
||||
|
||||
HIDDEN_NAMESPACE_END(torch, stable)
|
||||
|
||||
@ -24,12 +24,17 @@ T to(StableIValue val);
|
||||
// =============================================================================
|
||||
// =============================================================================
|
||||
// FROM CONVERSIONS (T -> StableIValue)
|
||||
// =============================================================================
|
||||
// ======================================================================
|
||||
|
||||
// Specialization for general copyable types (catch-all) => StableIValue
|
||||
template <typename T>
|
||||
struct FromImpl {
|
||||
static StableIValue call(T val) {
|
||||
static StableIValue call(
|
||||
T val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
static_assert(
|
||||
sizeof(T) <= sizeof(StableIValue),
|
||||
"StableLibrary stack does not support parameter types larger than 64 bits.");
|
||||
@ -68,7 +73,12 @@ struct FromImpl {
|
||||
using torch::headeronly::ScalarType;
|
||||
template <>
|
||||
struct FromImpl<ScalarType> {
|
||||
static StableIValue call(ScalarType val) {
|
||||
static StableIValue call(
|
||||
ScalarType val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
switch (val) {
|
||||
case ScalarType::Byte:
|
||||
return from(aoti_torch_dtype_uint8());
|
||||
@ -121,7 +131,12 @@ struct FromImpl<ScalarType> {
|
||||
// Specialization for std::nullopt_t => StableIValue
|
||||
template <>
|
||||
struct FromImpl<std::nullopt_t> {
|
||||
static StableIValue call(std::nullopt_t val) {
|
||||
static StableIValue call(
|
||||
std::nullopt_t val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
return from(nullptr);
|
||||
}
|
||||
};
|
||||
@ -157,11 +172,15 @@ struct FromImpl<std::nullopt_t> {
|
||||
// std::optional<T> or a std::nullopt.
|
||||
template <typename T>
|
||||
struct FromImpl<std::optional<T>> {
|
||||
static StableIValue call(const std::optional<T>& val) {
|
||||
static StableIValue call(
|
||||
const std::optional<T>& val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
if (!val.has_value()) {
|
||||
return from(std::nullopt);
|
||||
}
|
||||
return from(new StableIValue(from(val.value())));
|
||||
return from(new StableIValue(detail::FromImpl<T>::call(
|
||||
val.value(), extension_build_version, is_internal)));
|
||||
}
|
||||
};
|
||||
|
||||
@ -169,7 +188,12 @@ struct FromImpl<std::optional<T>> {
|
||||
// Returns a new owning reference of the underlying Tensor.
|
||||
template <>
|
||||
struct FromImpl<torch::stable::Tensor> {
|
||||
static StableIValue call(const torch::stable::Tensor& val) {
|
||||
static StableIValue call(
|
||||
const torch::stable::Tensor& val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
AtenTensorHandle new_ath;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath));
|
||||
return from(new_ath);
|
||||
@ -183,7 +207,12 @@ struct FromImpl<torch::stable::Tensor> {
|
||||
// Specialization for StableIValue => general copyable types (catch-all)
|
||||
template <typename T>
|
||||
struct ToImpl {
|
||||
static T call(StableIValue val) {
|
||||
static T call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
static_assert(std::is_trivially_copyable_v<T>);
|
||||
// T may not have a default constructor. (For example, it might be
|
||||
// c10::Device.) However, std::memcpy implicitly creates a T at the
|
||||
@ -218,7 +247,12 @@ struct ToImpl {
|
||||
// Specialization for StableIValue => torch::headeronly::ScalarType
|
||||
template <>
|
||||
struct ToImpl<ScalarType> {
|
||||
static ScalarType call(StableIValue val) {
|
||||
static ScalarType call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
int32_t shim_scalartype = to<int32_t>(val);
|
||||
if (shim_scalartype == aoti_torch_dtype_uint8()) {
|
||||
return ScalarType::Byte;
|
||||
@ -273,7 +307,12 @@ struct ToImpl<ScalarType> {
|
||||
// Specialization for StableIValue => std::nullopt_t
|
||||
template <>
|
||||
struct ToImpl<std::nullopt_t> {
|
||||
static std::nullopt_t call(StableIValue val) {
|
||||
static std::nullopt_t call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
// val should be equivalent to from(nullptr)
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -284,14 +323,18 @@ struct ToImpl<std::nullopt_t> {
|
||||
// from IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
|
||||
template <typename T>
|
||||
struct ToImpl<std::optional<T>> {
|
||||
static std::optional<T> call(StableIValue val) {
|
||||
static std::optional<T> call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
auto sivp = to<StableIValue*>(val);
|
||||
|
||||
// sivp is either nullptr or a pointer to a StableIValue
|
||||
if (sivp == nullptr) {
|
||||
return {};
|
||||
}
|
||||
auto inner_val = to<T>(*sivp);
|
||||
auto inner_val =
|
||||
detail::ToImpl<T>::call(*sivp, extension_build_version, is_internal);
|
||||
|
||||
// free the memory associated with StableIValue* sivp
|
||||
delete sivp;
|
||||
@ -305,7 +348,12 @@ struct ToImpl<std::optional<T>> {
|
||||
// underlying AtenTensorHandle.
|
||||
template <>
|
||||
struct ToImpl<torch::stable::Tensor> {
|
||||
static torch::stable::Tensor call(StableIValue val) {
|
||||
static torch::stable::Tensor call(
|
||||
StableIValue val,
|
||||
uint64_t extension_build_version,
|
||||
bool is_internal) {
|
||||
(void)extension_build_version; // Unused parameter
|
||||
(void)is_internal; // Unused parameter
|
||||
return torch::stable::Tensor(to<AtenTensorHandle>(val));
|
||||
}
|
||||
};
|
||||
@ -315,25 +363,60 @@ struct ToImpl<torch::stable::Tensor> {
|
||||
// =============================================================================
|
||||
|
||||
// Expose the partially templated class functions through single functions
|
||||
// The non-private versions will be used by the extension or headers that
|
||||
// the extension includes.
|
||||
template <typename T>
|
||||
inline StableIValue from(T val) {
|
||||
return detail::FromImpl<T>::call(val);
|
||||
return detail::FromImpl<T>::call(
|
||||
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline StableIValue from(const std::optional<T>& val) {
|
||||
return detail::FromImpl<std::optional<T>>::call(val);
|
||||
return detail::FromImpl<std::optional<T>>::call(
|
||||
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||
}
|
||||
|
||||
// The below overload is used! See https://godbolt.org/z/859cshxrW
|
||||
// We are suppressing the warning for versions clang12- and gcc11-
|
||||
[[maybe_unused]] inline StableIValue from(const torch::stable::Tensor& val) {
|
||||
return detail::FromImpl<torch::stable::Tensor>::call(val);
|
||||
return detail::FromImpl<torch::stable::Tensor>::call(
|
||||
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T to(StableIValue val) {
|
||||
return detail::ToImpl<T>::call(val);
|
||||
return detail::ToImpl<T>::call(
|
||||
val, aoti_torch_abi_version(), /*is_internal=*/false);
|
||||
}
|
||||
|
||||
// Internal conversion functions used by from_ivalue and to_ivalue.
|
||||
// These are used in libtorch
|
||||
template <typename T>
|
||||
inline StableIValue _from(T val, uint64_t extension_build_version) {
|
||||
return detail::FromImpl<T>::call(
|
||||
val, extension_build_version, /*is_internal=*/true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline StableIValue _from(
|
||||
const std::optional<T>& val,
|
||||
uint64_t extension_build_version) {
|
||||
return detail::FromImpl<std::optional<T>>::call(
|
||||
val, extension_build_version, /*is_internal=*/true);
|
||||
}
|
||||
|
||||
[[maybe_unused]] inline StableIValue _from(
|
||||
const torch::stable::Tensor& val,
|
||||
uint64_t extension_build_version) {
|
||||
return detail::FromImpl<torch::stable::Tensor>::call(
|
||||
val, extension_build_version, /*is_internal=*/true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T _to(StableIValue val, uint64_t extension_build_version) {
|
||||
return detail::ToImpl<T>::call(
|
||||
val, extension_build_version, /*is_internal=*/true);
|
||||
}
|
||||
|
||||
HIDDEN_NAMESPACE_END(torch, stable, detail)
|
||||
|
||||
torch/csrc/stable/version.h (new file, 29 lines)
@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/version.h>
|
||||
|
||||
// Stable ABI Version Targeting
|
||||
//
|
||||
// This header provides version targeting capabilities for the PyTorch Stable
|
||||
// ABI. Users can define TORCH_TARGET_VERSION to target a specific stable ABI
|
||||
// version instead of using the current TORCH_ABI_VERSION of libtorch at
|
||||
// compile time.
|
||||
//
|
||||
// Usage:
|
||||
// Default behavior (uses current ABI version):
|
||||
// #include <torch/csrc/stable/library.h>
|
||||
//
|
||||
// Target a specific stable version (major.minor) (e.g. PyTorch 2.9):
|
||||
// (1) Pass a compiler flag -DTORCH_TARGET_VERSION=0x0209000000000000
|
||||
// (2) Alternatively, define TORCH_TARGET_VERSION in the source code before
|
||||
// including any header files:
|
||||
// #define TORCH_TARGET_VERSION (((0ULL + 2) << 56) | ((0ULL + 9) << 48))
|
||||
// #include <torch/csrc/stable/library.h>
|
||||
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
#define TORCH_FEATURE_VERSION TORCH_TARGET_VERSION
|
||||
#else
|
||||
#define TORCH_FEATURE_VERSION TORCH_ABI_VERSION
|
||||
#endif
|
||||
|
||||
#define TORCH_VERSION_2_10_0 (((0ULL + 2) << 56) | ((0ULL + 10) << 48))
|
||||
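To make the byte layout concrete, a small illustrative helper (not part of the headers in this diff) that mirrors the MAJ/MIN/PATCH/ABI-tag packing; the 2.9 constant matches the example given in version.h above.

def pack_torch_abi_version(major, minor, patch=0, abi_tag=0):
    # [ MAJ ][ MIN ][ PATCH ][        ABI TAG        ]  packed into one uint64
    return (major << 56) | (minor << 48) | (patch << 40) | abi_tag

# Matches the documented TORCH_TARGET_VERSION value for PyTorch 2.9
assert pack_torch_abi_version(2, 9) == 0x0209000000000000
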
@ -671,8 +671,6 @@ class DTensor(torch.Tensor):
|
||||
def __metadata_guard__(
|
||||
cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool]
|
||||
) -> bool:
|
||||
# TODO - delete this - This is now unused after the PR -
|
||||
# https://github.com/pytorch/pytorch/pull/165824
|
||||
orig_spec, orig_requires_grad = orig
|
||||
other_spec, other_requires_grad = other
|
||||
return (
|
||||
|
||||
@ -6612,13 +6612,13 @@ class ShapeEnv:
|
||||
desc = "Could not guard on data-dependent expression"
|
||||
size_oblivious_result_msg = (
|
||||
"consider using data-dependent friendly APIs such as "
|
||||
"guard_or_false, guard_or_true and statically_known_true"
|
||||
"guard_or_false, guard_or_true and statically_known_true."
|
||||
)
|
||||
|
||||
msg = (
|
||||
f"{desc} {expr} (unhinted: {unhinted_expr}). "
|
||||
f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n"
|
||||
f"{size_oblivious_result_msg}"
|
||||
f"{size_oblivious_result_msg}\n"
|
||||
f"Caused by: {sloc}\n"
|
||||
'For more information, run with TORCH_LOGS="dynamic"\n'
|
||||
"For extended logs when we create symbols, also add "
|
||||
|
||||
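A brief sketch of the data-dependent friendly APIs the message points to (import path assumed to be torch.fx.experimental.symbolic_shapes):

from torch.fx.experimental.symbolic_shapes import guard_or_false

def maybe_squeeze(x):
    # Instead of guarding on a data-dependent size, fall back to the general
    # path whenever the comparison cannot be decided statically.
    if guard_or_false(x.shape[0] == 1):
        return x.squeeze(0)
    return x
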
@ -19,8 +19,8 @@
|
||||
/// Indicates the ABI version of LibTorch as a single uint64.
|
||||
/// [ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ]
|
||||
/// [ MAJ ][ MIN ][ PATCH][ ABI TAG ]
|
||||
#define TORCH_ABI_VERSION \
|
||||
(uint64_t)TORCH_VERSION_MAJOR << 56 | \
|
||||
(uint64_t)TORCH_VERSION_MINOR << 48 | \
|
||||
(uint64_t)TORCH_VERSION_PATCH << 40 | \
|
||||
TORCH_VERSION_ABI_TAG << 0
|
||||
#define TORCH_ABI_VERSION ( \
|
||||
((0ULL + TORCH_VERSION_MAJOR) << 56) | \
|
||||
((0ULL + TORCH_VERSION_MINOR) << 48) | \
|
||||
((0ULL + TORCH_VERSION_PATCH) << 40) | \
|
||||
((0ULL + TORCH_VERSION_ABI_TAG) << 0))

@ -14,11 +14,14 @@ from torch.backends.cuda import (
SDPAParams,
)

from .varlen import varlen_attn


__all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"varlen_attn",
]

# Note: [SDPA warnings]

@ -7,7 +7,7 @@ that calls into the optimized Flash Attention kernels.

import logging
from functools import lru_cache
from typing import Any, NamedTuple, Optional, Union
from typing import NamedTuple, Optional, Union

import torch

@ -33,7 +33,8 @@ class AuxRequest(NamedTuple):
lse: bool = False


@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
# import failures when I try to register as custom op
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
def _varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
@ -43,7 +44,7 @@ def _varlen_attn(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Private custom op for variable-length attention.

@ -69,7 +70,7 @@ def _varlen_attn(
False, # return_debug_mask
)
# cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
output, softmax_lse, rng_state = result[0], result[1], result[6]
output, softmax_lse = result[0], result[1]
else:
log.info("Using Flash Attention backend for varlen_attn")
output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
@ -85,13 +86,10 @@ def _varlen_attn(
return_debug_mask=False,
)

rng_state_ = torch.zeros(
(2,), dtype=torch.uint64, device=query.device
) # hardcoded since dropout is hardcoded to 0
return output, softmax_lse, rng_state_
return output, softmax_lse


@_varlen_attn.register_fake
# @_varlen_attn.register_fake
def _varlen_attn_fake(
query: torch.Tensor,
key: torch.Tensor,
@ -101,7 +99,7 @@ def _varlen_attn_fake(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.

@ -119,9 +117,7 @@ def _varlen_attn_fake(
(num_heads, total_q), dtype=torch.float, device=query.device
)

rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)

return output, logsumexp, rng_state
return output, logsumexp


def varlen_attn(
@ -195,145 +191,9 @@ def varlen_attn(
... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
... )
"""
out, lse, _ = torch.ops.torch_attn._varlen_attn(
out, lse = _varlen_attn(
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
)
if return_aux is not None and return_aux.lse:
return out, lse
return out
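The public wrapper above takes packed (total_tokens, heads, head_dim) tensors plus cumulative sequence-length offsets. Below is a minimal end-to-end sketch of calling it; the shapes, dtypes, and CUDA device are assumptions on my part, and the argument order follows the docstring example above.

import torch
from torch.nn.attention import varlen_attn  # re-exported via __all__ in this diff

# Three sequences of lengths 3, 5 and 2, packed along the token dimension.
seq_lens = torch.tensor([3, 5, 2], device="cuda", dtype=torch.int32)
cu_seq = torch.zeros(seq_lens.numel() + 1, device="cuda", dtype=torch.int32)
cu_seq[1:] = torch.cumsum(seq_lens, dim=0)  # cumulative offsets [0, 3, 8, 10]

total_tokens, num_heads, head_dim = int(cu_seq[-1]), 8, 64
q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
max_len = int(seq_lens.max())

out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=False)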


def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs
out, lse, rng_state = output
ctx.query = query
ctx.key = key
ctx.value = value
ctx.cu_seq_q = cu_seq_q
ctx.cu_seq_k = cu_seq_k
ctx.max_q = max_q
ctx.max_k = max_k
ctx.is_causal = is_causal
ctx.output = out
ctx.lse = lse
ctx.rng_state = rng_state


@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
def _varlen_attn_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
unused = torch.empty(0, device=query.device)

use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
log.info("Using cuDNN backend for varlen_attn")
dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
else:
log.info("Using Flash Attention backend for varlen_attn")
dq, dk, dv = torch.ops.aten._flash_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
return dq, dk, dv


@_varlen_attn_backward.register_fake
def _varlen_attn_backward_fake(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
"""

grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)

return grad_query, grad_key, grad_value


def _backward(
ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
query = ctx.query
key = ctx.key
value = ctx.value
cu_seq_q = ctx.cu_seq_q
cu_seq_k = ctx.cu_seq_k
max_q = ctx.max_q
max_k = ctx.max_k
is_causal = ctx.is_causal
out = ctx.output
lse = ctx.lse
rng_state = ctx.rng_state

# rng_state = torch.empty(2, device=query.device)

dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
is_causal,
rng_state,
)
return dq, dk, dv, None, None, None, None, None, None


_varlen_attn.register_autograd(_backward, setup_context=_setup_context)
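For readers less familiar with this registration pattern, here is a minimal, self-contained sketch of how torch.library.custom_op, register_fake, and register_autograd fit together. It uses a toy op rather than the attention kernels, so the op name and namespace below are purely illustrative and not part of this PR.

import torch

@torch.library.custom_op("demo::square", mutates_args=())
def square(x: torch.Tensor) -> torch.Tensor:
    # Eager implementation of the op.
    return x * x

@square.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation so the op works with meta tensors and tracing.
    return torch.empty_like(x)

def _setup(ctx, inputs, output) -> None:
    (x,) = inputs
    ctx.save_for_backward(x)

def _bwd(ctx, grad_out):
    (x,) = ctx.saved_tensors
    # d/dx (x**2) = 2 * x
    return grad_out * 2 * x

square.register_autograd(_bwd, setup_context=_setup)

x = torch.randn(4, requires_grad=True)
square(x).sum().backward()  # x.grad == 2 * x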

@ -265,10 +265,8 @@ def _private_register_pytree_node(
)


def _is_pytreespec_instance(
obj: Any, /
) -> TypeIs[Union[TreeSpec, python_pytree.TreeSpec]]:
return isinstance(obj, (TreeSpec, python_pytree.TreeSpec))
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]:
return isinstance(obj, TreeSpec)


def treespec_leaf() -> TreeSpec:
@ -974,7 +972,7 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
"""Serialize a treespec to a JSON string."""
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be instance of "
f"treespec_dumps(treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)


@ -20,7 +20,6 @@ import functools
import importlib
import importlib.metadata
import json
import sys
import threading
import types
import warnings
@ -37,11 +36,10 @@ from typing import (
Optional,
overload,
Protocol,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import deprecated, NamedTuple, Self, TypeIs
from typing_extensions import deprecated, NamedTuple, Self

from torch.torch_version import TorchVersion as _TorchVersion

@ -1338,39 +1336,6 @@ def treespec_dict(
return TreeSpec(dict, list(dct.keys()), list(dct.values()))


if TYPE_CHECKING:
import torch.utils._cxx_pytree as cxx


def _is_pytreespec_instance(obj: Any) -> TypeIs[Union[TreeSpec, "cxx.TreeSpec"]]:
if isinstance(obj, TreeSpec):
return True
if "torch.utils._cxx_pytree" in sys.modules:
# The C++ pytree module is not always available, so we check if it is loaded.
# If the C++ pytree module is loaded, we can check if the treespec
# is an instance of the C++ TreeSpec class.
from torch.utils._cxx_pytree import TreeSpec as CxxTreeSpec

if isinstance(obj, CxxTreeSpec):
return True
return False


def _ensure_python_treespec_instance(
treespec: Union[TreeSpec, "cxx.TreeSpec"],
) -> TreeSpec:
if isinstance(treespec, TreeSpec):
return treespec

if not _is_pytreespec_instance(treespec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
dummy_tree = treespec.unflatten([0] * treespec.num_leaves)
return tree_structure(dummy_tree)
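The helper removed above converts a foreign treespec into a Python one by round-tripping through a dummy tree: unflatten one placeholder per leaf, then recompute the structure with the target implementation. Here is a standalone sketch of that trick using only the public Python pytree helpers; the treespec is built with the Python implementation, so the assert merely illustrates the round trip.

import torch.utils._pytree as py_pytree

_, spec = py_pytree.tree_flatten({"a": [1, 2], "b": 3})

# Round-trip: one placeholder per leaf, then re-derive the structure.
dummy_tree = spec.unflatten([0] * spec.num_leaves)
rebuilt_spec = py_pytree.tree_structure(dummy_tree)
assert rebuilt_spec == spec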


def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
@ -1401,10 +1366,10 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
"""Given a list of values and a TreeSpec, builds a pytree.
This is the inverse operation of `tree_flatten`.
"""
if not _is_pytreespec_instance(treespec):
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"Expected `treespec` to be an instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
f"instance of TreeSpec but got item of type {type(treespec)}.",
)
return treespec.unflatten(leaves)

@ -1835,30 +1800,34 @@ def _broadcast_to_and_flatten(
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
def broadcast_prefix(
prefix_tree: PyTree,
full_tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> list[Any]:
result: list[Any] = []
if not isinstance(treespec, TreeSpec):
raise AssertionError("treespec must be a TreeSpec")

def add_leaves(x: Any, subtree: PyTree) -> None:
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
result.extend([x] * subtreespec.num_leaves)

tree_map_(
add_leaves,
prefix_tree,
full_tree,
is_leaf=is_leaf,
)
return result

full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
except ValueError:
if tree_is_leaf(tree, is_leaf=is_leaf):
return [tree] * treespec.num_leaves
if treespec.is_leaf():
return None
node_type = _get_node_type(tree)
if node_type != treespec.type:
return None

flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(tree)

# Check if the Node is different from the spec
if len(child_pytrees) != treespec.num_children or context != treespec._context:
return None

# Recursively flatten the children
result: list[Any] = []
for child, child_spec in zip(child_pytrees, treespec._children):
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
if flat is not None:
result += flat
else:
return None

return result
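To make the broadcast-prefix semantics above concrete: each leaf of the prefix tree is repeated once per leaf of the corresponding subtree of the full tree. The sketch below (not from the PR) reproduces the add_leaves logic with the public Python pytree helpers; the demo function name is mine.

import torch.utils._pytree as pytree

def broadcast_prefix_demo(prefix_tree, full_tree):
    result = []

    def add_leaves(x, subtree):
        subtreespec = pytree.tree_structure(subtree)
        result.extend([x] * subtreespec.num_leaves)

    # tree_map_ flattens full_tree only up to the structure of prefix_tree,
    # pairing each prefix leaf with the matching subtree.
    pytree.tree_map_(add_leaves, prefix_tree, full_tree)
    return result

# Prefix ["x", "y"] against [[1, 2, 3], (4, 5)] -> ['x', 'x', 'x', 'y', 'y']
print(broadcast_prefix_demo(["x", "y"], [[1, 2, 3], (4, 5)]))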


@dataclasses.dataclass
@ -1972,7 +1941,11 @@ _SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)


def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
treespec = _ensure_python_treespec_instance(treespec)
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
f"TreeSpec but got item of type {type(treespec)}.",
)

if protocol is None:
protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL